From 76684141a5d059be71cbe23dc2f0ed552213ba2d Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Sun, 22 Mar 2026 02:03:00 +0900 Subject: [PATCH 001/249] ruby : fix dangling pointers, memory leak, and SEGV on parallel transcription (#3715) * Prevent dangling pointers * Use proper free function * Free callback containers * Set default log callback when nil is passed to log_set * Raise error if callbacks set when parallel transcription * Bump version to 1.3.7 * Make tests follow spec change * Add note on parallel transcription and callbacks * Update signature of Whisper.log_set [skip ci] --- bindings/ruby/README.md | 2 + bindings/ruby/ext/ruby_whisper.c | 16 +++- bindings/ruby/ext/ruby_whisper.h | 1 + bindings/ruby/ext/ruby_whisper_context.c | 6 +- bindings/ruby/ext/ruby_whisper_params.c | 79 +++++++++++++++++-- bindings/ruby/ext/ruby_whisper_transcribe.cpp | 4 +- bindings/ruby/sig/whisper.rbs | 8 +- bindings/ruby/test/test_params.rb | 2 + bindings/ruby/test/test_whisper.rb | 29 +++++-- bindings/ruby/whispercpp.gemspec | 2 +- 10 files changed, 127 insertions(+), 22 deletions(-) diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index c6280a6926a..41e7b330d58 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -202,6 +202,8 @@ whisper.transcribe("path/to/audio.wav", params, n_processors: Etc.nprocessors) Note that transcription occasionally might be low accuracy when it works in parallel. +If n_processors is greater than 1, you cannot set any callbacks including new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, and log_callback set by Whisper.log_set. + ### Segments ### Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`: diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index ba71d4ba594..5f1917ee805 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -112,6 +112,10 @@ ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * return; } VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); + if (NIL_P(log_callback)) { + return; + } + VALUE udata = rb_iv_get(mWhisper, "user_data"); rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata); } @@ -129,10 +133,16 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d rb_iv_set(self, "log_callback", log_callback); rb_iv_set(self, "user_data", user_data); - VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); - rb_define_finalizer(log_callback, finalize_log_callback); + if (!NIL_P(log_callback)) { + VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); + rb_define_finalizer(log_callback, finalize_log_callback); + } - whisper_log_set(ruby_whisper_log_callback, NULL); + if (NIL_P(log_callback)) { + whisper_log_set(NULL, NULL); + } else { + whisper_log_set(ruby_whisper_log_callback, NULL); + } return Qnil; } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 8dfd103c17a..6b0b4df7214 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -2,6 +2,7 @@ #define RUBY_WHISPER_H #include +#include #include #include "whisper.h" diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index c39d43bd76c..6e38ead6321 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -22,7 +22,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type; extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_s_new(VALUE context); extern VALUE rb_whisper_segment_s_new(VALUE context, int index); -extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context); +extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors); ID transcribe_option_names[1]; @@ -436,7 +436,7 @@ full_body(VALUE rb_args) GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context); + prepare_transcription(rwp, args->context, 1); int result = whisper_full(rw->context, rwp->params, args->samples, args->n_samples); return INT2NUM(result); @@ -487,7 +487,7 @@ full_parallel_body(VALUE rb_args) GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context); + prepare_transcription(rwp, args->context, args->n_processors); int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors); return INT2NUM(result); diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 61eb1733676..3e5dca9c1e1 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -29,6 +29,7 @@ extern VALUE cParams; extern VALUE cVADParams; +extern VALUE mWhisper; extern ID id_call; @@ -186,6 +187,35 @@ static bool abort_callback(void * user_data) { return false; } +static void +check_thread_safety(ruby_whisper_params *rwp, VALUE *context, int n_processors) +{ + if (n_processors == 1) { + return; + } + + if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { + rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription"); + } + + if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription"); + } + + if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) { + rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); + } + + if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { + rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); + } + + VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); + if (!NIL_P(log_callback)) { + rb_raise(rb_eRuntimeError, "log callback not supported for parallel transcription"); + } +} + static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { rwp->new_segment_callback_container->context = context; @@ -219,9 +249,13 @@ static void set_vad_params(ruby_whisper_params *rwp) rwp->params.vad_params = rwvp->params; } +/* + TODO: Set abort callback to trap SIGINT and SIGTERM +*/ void -prepare_transcription(ruby_whisper_params *rwp, VALUE *context) +prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors) { + check_thread_safety(rwp, context, n_processors); register_callbacks(rwp, context); set_vad_params(rwp); } @@ -240,6 +274,20 @@ rb_whisper_params_mark(void *p) void ruby_whisper_params_free(ruby_whisper_params *rwp) { + if (rwp->params.language) { + ruby_xfree((void *)rwp->params.language); + } + if (rwp->params.initial_prompt) { + ruby_xfree((void *)rwp->params.initial_prompt); + } + if (rwp->params.vad_model_path) { + ruby_xfree((void *)rwp->params.vad_model_path); + } + + xfree(rwp->new_segment_callback_container); + xfree(rwp->progress_callback_container); + xfree(rwp->encoder_begin_callback_container); + xfree(rwp->abort_callback_container); } void @@ -248,7 +296,7 @@ rb_whisper_params_free(void *p) ruby_whisper_params *rwp = (ruby_whisper_params *)p; // How to free user_data and callback only when not referred to by others? ruby_whisper_params_free(rwp); - free(rwp); + xfree(rwp); } static size_t @@ -276,6 +324,15 @@ ruby_whisper_params_allocate(VALUE klass) ruby_whisper_params *rwp; VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + if (rwp->params.language != NULL) { + rwp->params.language = ruby_strdup(rwp->params.language); + } + if (rwp->params.initial_prompt != NULL) { + rwp->params.initial_prompt = ruby_strdup(rwp->params.initial_prompt); + } + if (rwp->params.vad_model_path != NULL) { + rwp->params.vad_model_path = ruby_strdup(rwp->params.vad_model_path); + } rwp->diarize = false; rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params); rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); @@ -296,10 +353,12 @@ ruby_whisper_params_set_language(VALUE self, VALUE value) { ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + ruby_xfree((void *)rwp->params.language); + rwp->params.language = NULL; if (value == Qfalse || value == Qnil) { - rwp->params.language = "auto"; + rwp->params.language = ruby_strdup("auto"); } else { - rwp->params.language = StringValueCStr(value); + rwp->params.language = ruby_strdup(StringValueCStr(value)); } return value; } @@ -608,7 +667,13 @@ ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) { ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); - rwp->params.initial_prompt = StringValueCStr(value); + ruby_xfree((void *)rwp->params.initial_prompt); + rwp->params.initial_prompt = NULL; + if (NIL_P(value)) { + rwp->params.initial_prompt = NULL; + } else { + rwp->params.initial_prompt = ruby_strdup(StringValueCStr(value)); + } return value; } /* @@ -1103,12 +1168,14 @@ ruby_whisper_params_set_vad_model_path(VALUE self, VALUE value) { ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + ruby_xfree((void *)rwp->params.vad_model_path); + rwp->params.vad_model_path = NULL; if (NIL_P(value)) { rwp->params.vad_model_path = NULL; return value; } VALUE path = ruby_whisper_normalize_model_path(value); - rwp->params.vad_model_path = StringValueCStr(path); + rwp->params.vad_model_path = ruby_strdup(StringValueCStr(path)); return value; } diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index c00fbcd1def..3d00566009a 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -16,7 +16,7 @@ extern ID id_to_path; extern ID transcribe_option_names[1]; extern void -prepare_transcription(ruby_whisper_params * rwp, VALUE * self); +prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); /* * transcribe a single file @@ -73,7 +73,7 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { // rwp->params.encoder_begin_callback_user_data = &is_aborted; // } - prepare_transcription(rwp, &self); + prepare_transcription(rwp, &self, n_processors); if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) { fprintf(stderr, "failed to process audio\n"); diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 9ade451c6b2..3c59661975b 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -37,7 +37,7 @@ module Whisper def self.lang_id: (string name) -> Integer def self.lang_str: (Integer id) -> String def self.lang_str_full: (Integer id) -> String - def self.log_set: (log_callback, Object? user_data) -> log_callback + def self.log_set: (log_callback?, Object? user_data) -> log_callback def self.system_info_str: () -> String class Context @@ -52,6 +52,9 @@ module Whisper # puts text # end # + # If n_processors is greater than 1, you cannot set any callbacks including + # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, + # and log_callback set by Whisper.log_set def transcribe: (path, Params, ?n_processors: Integer) -> self | (path, Params, ?n_processors: Integer) { (String) -> void } -> self @@ -129,6 +132,9 @@ module Whisper # It seems this approach can offer some speedup in some cases. # However, the transcription accuracy can be worse at the beginning and end of each chunk. # + # If n_processors is greater than 1, you cannot set any callbacks including + # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, + # and log_callback set by Whisper.log_set def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self | (Params, _Samples, ?Integer n_samples) -> self | (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self diff --git a/bindings/ruby/test/test_params.rb b/bindings/ruby/test/test_params.rb index 094dba6f48e..ff5c28e9043 100644 --- a/bindings/ruby/test/test_params.rb +++ b/bindings/ruby/test/test_params.rb @@ -46,6 +46,8 @@ def setup def test_language @params.language = "en" assert_equal @params.language, "en" + GC.compact + assert_equal @params.language, "en" @params.language = "auto" assert_equal @params.language, "auto" end diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index 29071210072..f7e25239d5d 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -43,9 +43,20 @@ def test_transcribe_n_processors @whisper = Whisper::Context.new("base.en") params = Whisper::Params.new - @whisper.transcribe(AUDIO, params, n_processors: 4) {|text| - assert_match(/what you can do for your country/i, text) - } + without_log_callback do + @whisper.transcribe(AUDIO, params, n_processors: 4) {|text| + assert_match(/what you can do for your country/i, text) + } + end + end + + private + + def without_log_callback + Whisper.log_set nil, nil + yield + ensure + Whisper.log_set ->(level, buffer, user_data) {}, nil end sub_test_case "After transcription" do @@ -229,7 +240,9 @@ def test_full_with_memroy_view_gc def test_full_parallel nprocessors = 2 - @whisper.full_parallel(@params, @samples, @samples.length, nprocessors) + without_log_callback do + @whisper.full_parallel(@params, @samples, @samples.length, nprocessors) + end assert_equal nprocessors, @whisper.full_n_segments text = @whisper.each_segment.collect(&:text).join @@ -240,7 +253,9 @@ def test_full_parallel def test_full_parallel_with_memory_view nprocessors = 2 samples = JFKReader.new(AUDIO) - @whisper.full_parallel(@params, samples, nil, nprocessors) + without_log_callback do + @whisper.full_parallel(@params, samples, nil, nprocessors) + end assert_equal nprocessors, @whisper.full_n_segments text = @whisper.each_segment.collect(&:text).join @@ -259,7 +274,9 @@ def test_full_parallel_without_length_and_n_processors def test_full_parallel_without_length nprocessors = 2 - @whisper.full_parallel(@params, @samples, nil, nprocessors) + without_log_callback do + @whisper.full_parallel(@params, @samples, nil, nprocessors) + end assert_equal nprocessors, @whisper.full_n_segments text = @whisper.each_segment.collect(&:text).join diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 88b94e7eb8a..2d952222f29 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -3,7 +3,7 @@ require_relative "extsources" Gem::Specification.new do |s| s.name = "whispercpp" s.authors = ["Georgi Gerganov", "Todd A. Fisher"] - s.version = '1.3.6' + s.version = '1.3.7' s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} s.email = 'todd.fisher@gmail.com' s.extra_rdoc_files = ['LICENSE', 'README.md'] From 1335dfa785af56c55bf510275f86f795db2fe474 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 15 Mar 2026 19:10:15 +0100 Subject: [PATCH 002/249] sycl : fix for untransposed GDA recurrent state (llama/20583) --- ggml/src/ggml-sycl/gated_delta_net.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index 8c76afbd571..648455c134b 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -55,7 +55,7 @@ void gated_delta_net_sycl(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[i * S_v + col]; + s_shard[r] = curr_state[col * S_v + i]; } for (int t = 0; t < n_tokens; t++) { @@ -137,7 +137,7 @@ void gated_delta_net_sycl(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - state[i * S_v + col] = s_shard[r]; + state[col * S_v + i] = s_shard[r]; } } From dae7781052d858a38d5d57eb2b252364d6e2c6d0 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 16 Mar 2026 11:41:45 +0800 Subject: [PATCH 003/249] CUDA: GDN hide memory latency (llama/20537) --- ggml/src/ggml-cuda/gated_delta_net.cu | 32 ++++++++++++++++++--------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 1ce6d5f31b5..6b44bec7317 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,7 +1,8 @@ #include "gated_delta_net.cuh" template -__global__ void gated_delta_net_cuda(const float * q, +__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) +gated_delta_net_cuda(const float * q, const float * k, const float * v, const float * g, @@ -38,7 +39,7 @@ __global__ void gated_delta_net_cuda(const float * q, const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; state += state_offset; - curr_state += state_offset; + curr_state += state_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; @@ -46,10 +47,11 @@ __global__ void gated_delta_net_cuda(const float * q, constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; float s_shard[rows_per_lane]; // state is stored transposed: M[col][i] = S[i][col], row col is contiguous + #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[col * S_v + i]; + s_shard[r] = curr_state[i]; } for (int t = 0; t < n_tokens; t++) { @@ -63,6 +65,16 @@ __global__ void gated_delta_net_cuda(const float * q, const float beta_val = *beta_t; + // Cache k and q in registers + float k_reg[rows_per_lane]; + float q_reg[rows_per_lane]; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + k_reg[r] = k_t[i]; + q_reg[r] = q_t[i]; + } + if constexpr (!KDA) { const float g_val = expf(*g_t); @@ -70,8 +82,7 @@ __global__ void gated_delta_net_cuda(const float * q, float kv_shard = 0.0f; #pragma unroll for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - kv_shard += s_shard[r] * k_t[i]; + kv_shard += s_shard[r] * k_reg[r]; } float kv_col = warp_reduce_sum(kv_shard); @@ -83,9 +94,8 @@ __global__ void gated_delta_net_cuda(const float * q, float attn_partial = 0.0f; #pragma unroll for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col; - attn_partial += s_shard[r] * q_t[i]; + s_shard[r] = g_val * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; } float attn_col = warp_reduce_sum(attn_partial); @@ -99,7 +109,7 @@ __global__ void gated_delta_net_cuda(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i]; + kv_shard += expf(g_t[i]) * s_shard[r] * k_reg[r]; } float kv_col = warp_reduce_sum(kv_shard); @@ -113,8 +123,8 @@ __global__ void gated_delta_net_cuda(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col; - attn_partial += s_shard[r] * q_t[i]; + s_shard[r] = expf(g_t[i]) * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; } float attn_col = warp_reduce_sum(attn_partial); From 724ea71cf97e4887809093c7fc75b7bfe34506f4 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Mon, 16 Mar 2026 10:45:49 +0100 Subject: [PATCH 004/249] vulkan: fix flash attention dot product precision (llama/20589) --- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index ec48f5b1152..11b7dce8578 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -245,7 +245,7 @@ void main() { #endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf)); + Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf)); } } } @@ -270,7 +270,7 @@ void main() { #endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf)); + Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf)); } } } From 9232af59ba3cf8ed9050d4a7229fd8a30e9eeeb8 Mon Sep 17 00:00:00 2001 From: Martin Klacer Date: Mon, 16 Mar 2026 19:25:54 +0000 Subject: [PATCH 005/249] kleidiai: add data type check to get_tensor_traits (llama/20639) * kleidiai: add data type check to get_tensor_traits * Added check for F16 data type into get_tensor_traits path with input data not in ggml_backend_cpu_kleidiai_buffer_type format (unsupported for Q4/8) Signed-off-by: Martin Klacer Change-Id: I9aca4b9b8d669d35db6f1dbcc4e080b1919b1de7 * updated ggml/src/ggml-cpu/kleidiai/kleidiai.cpp updated kleidiai.cpp file as per suggestion Co-authored-by: Georgi Gerganov --------- Signed-off-by: Martin Klacer Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 9bcc18d442c..7a5924944a8 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -1473,10 +1473,12 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } else { + if (op->src[0]->type != GGML_TYPE_F16) { + return nullptr; + } std::array kernel_chain; const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain); - const bool has_kernel = slot_total > 0; - if (has_kernel && op->src[1]->ne[1] > 1) { + if (slot_total > 0 && op->src[1]->ne[1] > 1) { if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { return nullptr; From 64942511978ccb23975c17b44513fa7fe858b8a2 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Tue, 17 Mar 2026 10:01:52 +0800 Subject: [PATCH 006/249] ehance UPSCALE to support all UT cases (llama/20637) * [SYCL] ehance UPSCALE to support more cases * rm test case result of SYCL1 --- ggml/src/ggml-sycl/backend.hpp | 4 +- ggml/src/ggml-sycl/element_wise.cpp | 89 ------ ggml/src/ggml-sycl/element_wise.hpp | 2 - ggml/src/ggml-sycl/ggml-sycl.cpp | 4 +- ggml/src/ggml-sycl/upscale.cpp | 410 ++++++++++++++++++++++++++++ ggml/src/ggml-sycl/upscale.hpp | 9 + 6 files changed, 423 insertions(+), 95 deletions(-) create mode 100644 ggml/src/ggml-sycl/upscale.cpp create mode 100644 ggml/src/ggml-sycl/upscale.hpp diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index b30b7f2beb7..a526d8e58bc 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -24,6 +24,7 @@ #include "dmmv.hpp" #include "element_wise.hpp" #include "fattn.hpp" +#include "gated_delta_net.hpp" #include "gla.hpp" #include "im2col.hpp" #include "mmq.hpp" @@ -31,6 +32,7 @@ #include "norm.hpp" #include "outprod.hpp" #include "pad.hpp" +#include "pad_reflect_1d.hpp" #include "quantize.hpp" #include "quants.hpp" #include "roll.hpp" @@ -39,8 +41,8 @@ #include "ssm_conv.hpp" #include "softmax.hpp" #include "tsembd.hpp" +#include "upscale.hpp" #include "wkv.hpp" -#include "pad_reflect_1d.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index acd51bf45b2..ec0247528c4 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -294,30 +294,6 @@ static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl: } } -template -static void upscale(const T *x, T *dst, const int nb00, const int nb01, - const int nb02, const int nb03, const int ne10, const int ne11, - const int ne12, const int ne13, const float sf0, const float sf1, - const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { - int index = item_ct1.get_local_id(0) + - item_ct1.get_group(0) * item_ct1.get_local_range(0); - if (index >= ne10 * ne11 * ne12 * ne13) { - return; - } - // operation - int i10 = index % ne10; - int i11 = (index / ne10) % ne11; - int i12 = (index / (ne10 * ne11)) % ne12; - int i13 = (index / (ne10 * ne11 * ne12)) % ne13; - - int i00 = static_cast(i10 / sf0); - int i01 = static_cast(i11 / sf1); - int i02 = static_cast(i12 / sf2); - int i03 = static_cast(i13 / sf3); - - dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); -} - template static void clamp(const T * x, T * dst, const float min, const float max, const int k, const sycl::nd_item<1> &item_ct1) { @@ -392,20 +368,6 @@ static void arange_kernel(T * dst, const int k, T start, T step, } } -template -static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, - const int nb02, const int nb03, const int ne10, const int ne11, - const int ne12, const int ne13, const float sf0, const float sf1, - const float sf2, const float sf3, queue_ptr stream) { - int dst_size = ne10 * ne11 * ne12 * ne13; - int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE); - sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); - stream->parallel_for( - sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { - upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); - }); -} - template static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); @@ -505,42 +467,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c } } -template -static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - - GGML_ASSERT(dst->src[0]->type == dst->type); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0]; - const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1]; - const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2]; - const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3]; - switch (dst->type) { - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2], - (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3, - main_stream, std::forward(args)...); - break; - } - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2], - (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3, - main_stream, std::forward(args)...); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - template static inline void ggml_sycl_op_unary( ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) { @@ -784,15 +710,6 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor }); } -static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst, - [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03, - int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3, - queue_ptr stream) { - ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream); - }); -} - static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { float min_val; float max_val; @@ -1131,12 +1048,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_sqr(ctx, dst); } -void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); - ggml_sycl_op_upscale(ctx, dst); -} - - void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_clamp(ctx, dst); diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 7c71974687a..997132166ab 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -71,8 +71,6 @@ void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst); -void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); - void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 12819705849..2ec1421841b 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -44,7 +44,6 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/common.hpp" #include "ggml-sycl/element_wise.hpp" -#include "ggml-sycl/gated_delta_net.hpp" #include "ggml-sycl/gemm.hpp" #include "ggml-sycl/getrows.hpp" #include "ggml-sycl/norm.hpp" @@ -4863,9 +4862,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: case GGML_OP_IM2COL: - return true; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); + return true; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: diff --git a/ggml/src/ggml-sycl/upscale.cpp b/ggml/src/ggml-sycl/upscale.cpp new file mode 100644 index 00000000000..18c743de447 --- /dev/null +++ b/ggml/src/ggml-sycl/upscale.cpp @@ -0,0 +1,410 @@ +#include "upscale.hpp" + +static void upscale_f32(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int ne13, + const float sf0, const float sf1, const float sf2, const float sf3) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + int index = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (index >= ne10 * ne11 * ne12 * ne13) { + return; + } + + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = (index / (ne10 * ne11 * ne12)) % ne13; + + int i00 = i10 / sf0; + int i01 = i11 / sf1; + int i02 = i12 / sf2; + int i03 = i13 / sf3; + + dst[index] = *((const float*)((const char*)x + i03 * nb03 + i02 * nb02 + + i01 * nb01 + i00 * nb00)); +} + +static void upscale_f32_bilinear(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + int y0_src = (int) sycl::floor((float) y_src_f); + int y1_src = y0_src + 1; + + y0_src = sycl::max(0, sycl::min(y0_src, ne01_src - 1)); + y1_src = sycl::max(0, sycl::min(y1_src, ne01_src - 1)); + + float dy = y_src_f - (float)y0_src; + dy = sycl::max(0.0f, sycl::min(dy, 1.0f)); + + float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + int x0_src = (int) sycl::floor(x_src_f); + int x1_src = x0_src + 1; + + x0_src = sycl::max(0, sycl::min(x0_src, ne00_src - 1)); + x1_src = sycl::max(0, sycl::min(x1_src, ne00_src - 1)); + + float dx = x_src_f - (float)x0_src; + dx = sycl::max(0.0f, sycl::min(dx, 1.0f)); + + const float* p_a = + (const float*)((const char*)x + (int64_t)x0_src * nb00 + + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + const float* p_b = + (const float*)((const char*)x + (int64_t)x1_src * nb00 + + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + const float* p_c = + (const float*)((const char*)x + (int64_t)x0_src * nb00 + + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + const float* p_d = + (const float*)((const char*)x + (int64_t)x1_src * nb00 + + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + + const float val_a = *p_a; + const float val_b = *p_b; + const float val_c = *p_c; + const float val_d = *p_d; + + float result = val_a * (1.0f - dx) * (1.0f - dy) + + val_b * dx * (1.0f - dy) + + val_c * (1.0f - dx) * dy + + val_d * dx * dy; + + dst[index] = result; +} + +// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) +// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp +static void upscale_f32_bilinear_antialias(const float * src0, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne00_src, + const int ne01_src, + const int ne10_dst, + const int ne11_dst, + const int ne12_dst, + const int ne13_dst, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + const float pixel_offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y = ((float)i11_dst + pixel_offset) / sf1; + const float x = ((float)i10_dst + pixel_offset) / sf0; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = sycl::max(1.0f / sf1, 1.0f); + const float invscale1 = 1.0f / support1; + const float support0 = sycl::max(1.0f / sf0, 1.0f); + const float invscale0 = 1.0f / support0; + + // the range of source pixels that contribute + const int64_t x_min = sycl::max(int64_t(0), int64_t(x - support0 + pixel_offset)); + const int64_t x_max = sycl::min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset)); + const int64_t y_min = sycl::max(int64_t(0), int64_t(y - support1 + pixel_offset)); + const int64_t y_max = sycl::min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset)); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + auto triangle_filter = [](float x) -> float { + return sycl::max(1.0f - sycl::fabs(x), 0.0f); + }; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = + *(const float*)((const char*)src0 + sx * nb00 + sy * nb01 + + i02_src * nb02 + i03_src * nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + dst[index] = val; +} + +namespace bicubic_interpolation { +static float weight1(float x, const float &a) { return ((a + 2) * x - (a + 3)) * x * x + 1; }; +static float weight2(float x, const float &a) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; }; + +static float bicubic(float p0, float p1, float p2, float p3, float x, float a) { + const float w0 = weight2(x + 1, a); + const float w1 = weight1(x + 0, a); + const float w2 = weight1(1 - x, a); + const float w3 = weight2(2 - x, a); + return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3; +}; + +} + +static void upscale_f32_bicubic(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const float a = -0.75f; + using bicubic_interpolation::bicubic; + + const int64_t index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + const int64_t dst_total_elements = + ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + const int y0_src = (int) sycl::floor((float) y_src_f); + const float dy = y_src_f - (float)y0_src; + + const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + const int x0_src = (int) sycl::floor((float) x_src_f); + const float dx = x_src_f - (float)x0_src; + + const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03; + + auto load = [=](int x_off, int y_off) -> float { + int i00_src = sycl::max(0, sycl::min(x0_src + x_off, ne00_src - 1)); + int i01_src = sycl::max(0, sycl::min(y0_src + y_off, ne01_src - 1)); + return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01); + }; + + const float result = bicubic( + bicubic(load(-1, -1), load(0, -1), load(1, -1), load(2, -1), dx, a), + bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx, a), + bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx, a), + bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx, a), + dy, + a); + + dst[index] = result; +} + +static void upscale_f32_sycl(const float * x, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + dpct::queue_ptr stream) { + const int64_t dst_size = ne10 * ne11 * ne12 * ne13; + const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3); + }); +} + +static void upscale_f32_bilinear_sycl(const float * x, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne00_src, + const int ne01_src, + const int ne10_dst, + const int ne11_dst, + const int ne12_dst, + const int ne13_dst, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + const float pixel_offset, + bool antialias, + dpct::queue_ptr stream) { + const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + + if (antialias) { + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + upscale_f32_bilinear_antialias( + x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, + ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + }); + } else { + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + upscale_f32_bilinear( + x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, + ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + }); + } +} + +static void upscale_f32_bicubic_sycl(const float * x, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne00_src, + const int ne01_src, + const int ne10_dst, + const int ne11_dst, + const int ne12_dst, + const int ne13_dst, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + const float pixel_offset, + dpct::queue_ptr stream) { + const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + + { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + upscale_f32_bicubic( + x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, + ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + }); + }); + } +} + +void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int mode_flags = dst->op_params[0]; + const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF); + + float sf0 = (float)dst->ne[0]/src0->ne[0]; + float sf1 = (float)dst->ne[1]/src0->ne[1]; + float sf2 = (float)dst->ne[2]/src0->ne[2]; + const float sf3 = (float)dst->ne[3]/src0->ne[3]; + + float pixel_offset = 0.5f; + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 + ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) + : sf0; + sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 + ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) + : sf1; + pixel_offset = 0.0f; + } + + if (mode == GGML_SCALE_MODE_NEAREST) { + upscale_f32_sycl( + src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); + upscale_f32_bilinear_sycl( + src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + sf0, sf1, sf2, sf3, pixel_offset, antialias, stream); + } else if (mode == GGML_SCALE_MODE_BICUBIC) { + upscale_f32_bicubic_sycl( + src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + sf0, sf1, sf2, sf3, pixel_offset, stream); + } +} + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_upscale(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/upscale.hpp b/ggml/src/ggml-sycl/upscale.hpp new file mode 100644 index 00000000000..c36c1bdc970 --- /dev/null +++ b/ggml/src/ggml-sycl/upscale.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include +#include "dpct/helper.hpp" +#include "common.hpp" + +#define SYCL_UPSCALE_BLOCK_SIZE 256 + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); From 49adc8b470cb1c09d01d313b6fc1859d43658158 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 17 Mar 2026 10:09:59 +0100 Subject: [PATCH 007/249] vulkan: allow graphics queue only through env var (llama/20599) * vulkan: avoid graphics queue on non-RADV AMD drivers * avoid graphics queues on small GPUs * change to only use graphics queue if overridden with env var GGML_VK_ALLOW_GRAPHICS_QUEUE * reenable transfer queue if graphics queue is not used --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7092361d2ea..e9b6778d628 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4981,8 +4981,9 @@ static vk_device ggml_vk_get_device(size_t idx) { std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues - // On AMD, the graphics queue seems to be faster, so don't avoid it - const vk::QueueFlagBits graphics_flag = device->vendor_id == VK_VENDOR_ID_AMD ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics; + // Allow overriding avoiding the graphics queue because it can increase performance on RADV + const bool allow_graphics_queue = (getenv("GGML_VK_ALLOW_GRAPHICS_QUEUE") != nullptr); + const vk::QueueFlagBits graphics_flag = allow_graphics_queue ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics; const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1); const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1); @@ -5443,11 +5444,14 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_load_shaders(device); + // Only use transfer queue on AMD non-GCN, when the graphics queue is not enabled + const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !allow_graphics_queue; + if (!device->single_queue) { const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); - device->async_use_transfer_queue = (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr); + device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr); } else { // TODO: Use pointer or reference to avoid copy device->transfer_queue.copyFrom(device->compute_queue); From ab7d305b751ddd8c50f3beeddb95eaf02d19741a Mon Sep 17 00:00:00 2001 From: Justin Bradford Date: Tue, 17 Mar 2026 05:03:54 -0700 Subject: [PATCH 008/249] kleidiai : fix MUL_MAT support for batched (3D) inputs (llama/20620) * kleidiai : fix MUL_MAT support for batched (3D) inputs The supports_op() check incorrectly rejected MUL_MAT operations with 3D inputs (ne[2] > 1), but the actual compute_forward_qx() implementation handles batched inputs correctly via a loop over ne12. This caused models with Q4_0/Q8_0 weights to crash during graph scheduling when n_seq_max > 1, because weights were placed in KLEIDIAI buffers during loading (tested with 2D inputs) but the runtime used 3D inputs. Also relax the buffer check to allow supports_op() to be called during weight loading when src[0]->buffer is NULL. Fixes #20608 * Kleidiai support_ops should only return true for 3D inputs, not also 4D --- ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 7a5924944a8..0ecf7ae02ac 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -1461,7 +1461,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { return false; } if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) && - ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) { + ggml_ne(op->src[1], 3) == 1) { return true; } } From 0ad6ceef59777414829eb8167e4e12022c03be21 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 17 Mar 2026 14:27:23 +0100 Subject: [PATCH 009/249] vulkan: async and event fixes (llama/20518) * vulkan: fix event wait submission, event command buffer reset * fix event command buffer reset validation error * also reset command buffers before reuse * use timeline semaphores instead of fences for event_synchronize * don't use initializer list for semaphore wait info * use multiple events to avoid reset issues * fix event reuse issue with multiple vectors * add semaphore wait condition also if compute_ctx already exists * remove event pending stage --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 128 ++++++++++++++++++++------- 1 file changed, 95 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e9b6778d628..3d8ce10676e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -191,6 +191,7 @@ struct vk_queue; struct vk_command_buffer { vk::CommandBuffer buf; + uint64_t use_counter = 0; bool in_use = false; }; @@ -938,19 +939,24 @@ struct vk_subbuffer { } }; -// vk_event is used for the event-related backend interfaces. It uses 'event' for -// event_wait and 'fence' for event_synchronize. Polling on an event for +struct vk_semaphore { + vk::Semaphore s; + uint64_t value; +}; + +// vk_event is used for the event-related backend interfaces. It uses vk::Events for +// event_wait and a timeline semaphore for event_synchronize. Polling on an event for // event_synchronize wouldn't be sufficient to wait for command buffers to complete, // and would lead to validation errors. struct vk_event { + std::vector events_free; // Events available for reuse + std::vector events_submitted; // Events that are fully submitted and can be reused on next synchronize vk::Event event; - vk::Fence fence; - vk_command_buffer* cmd_buffer = nullptr; -}; + bool has_event; -struct vk_semaphore { - vk::Semaphore s; - uint64_t value; + vk_semaphore tl_semaphore; + vk_command_buffer* cmd_buffer = nullptr; + uint64_t cmd_buffer_use_counter = 0; }; struct vk_submission { @@ -2319,7 +2325,7 @@ static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_comman vk::CommandBufferLevel::ePrimary, 1); const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); - p.cmd_buffers.push_back({ cmd_buffers.front(), true }); + p.cmd_buffers.push_back({ cmd_buffers.front(), 0, true }); return &p.cmd_buffers[p.cmd_buffers.size()-1]; } @@ -2788,6 +2794,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct ); } +static void ggml_vk_reset_event(vk_context& ctx, vk::Event& event) { + VK_LOG_DEBUG("ggml_vk_set_event()"); + + ctx->s->buffer->buf.resetEvent( + event, + ctx->p->q->stage_flags + ); +} + static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) { VK_LOG_DEBUG("ggml_vk_set_event()"); @@ -6396,6 +6411,7 @@ static vk_subbuffer ggml_vk_tensor_subbuffer( static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) { for (auto& cmd_buffer : pool.cmd_buffers) { if (!cmd_buffer.in_use) { + cmd_buffer.use_counter++; cmd_buffer.in_use = true; return &cmd_buffer; } @@ -6500,14 +6516,15 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { } static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) { + vk_context result; if (!ctx->compute_ctx.expired()) { - return ctx->compute_ctx.lock(); - } - - vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + result = ctx->compute_ctx.lock(); + } else { + result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = result; - ggml_vk_ctx_begin(ctx->device, result); + ctx->compute_ctx = result; + ggml_vk_ctx_begin(ctx->device, result); + } if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) { result->s->wait_semaphores.push_back(ctx->transfer_semaphore); @@ -13801,6 +13818,7 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { ctx->submit_pending = false; if (cmd_buf) { cmd_buf->in_use = false; + cmd_buf->buf.reset(); } } @@ -14862,18 +14880,31 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset - // the backend interface doesn't have an explicit reset, so reset it here - // before we record the command to set it - ctx->device->device.resetEvent(vkev->event); - ctx->device->device.resetFences({ vkev->fence }); + if (vkev->has_event) { + // Move existing event into submitted + vkev->events_submitted.push_back(vkev->event); + } + + // Grab the next event and record it, create one if necessary + if (vkev->events_free.empty()) { + vkev->event = ctx->device->device.createEvent({}); + } else { + vkev->event = vkev->events_free.back(); + vkev->events_free.pop_back(); + } + + vkev->has_event = true; ggml_vk_set_event(compute_ctx, vkev->event); + vkev->tl_semaphore.value++; + compute_ctx->s->signal_semaphores.push_back(vkev->tl_semaphore); ggml_vk_ctx_end(compute_ctx); - ggml_vk_submit(compute_ctx, {vkev->fence}); + ggml_vk_submit(compute_ctx, {}); ctx->submit_pending = true; vkev->cmd_buffer = cmd_buf; + vkev->cmd_buffer_use_counter = cmd_buf->use_counter; ctx->compute_ctx.reset(); } @@ -14884,9 +14915,10 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); - ggml_vk_wait_events(compute_ctx, {vkev->event}); - ggml_vk_ctx_end(compute_ctx); - ctx->compute_ctx.reset(); + if (vkev->has_event) { + // Wait for latest event + ggml_vk_wait_events(compute_ctx, { vkev->event }); + } } // TODO: enable async and synchronize @@ -15676,10 +15708,13 @@ static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t return nullptr; } - // The event/fence is expected to initially be in the signaled state. - vkev->event = device->device.createEvent({}); - vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled}); - device->device.setEvent(vkev->event); + // No events initially, they get created on demand + vkev->has_event = false; + + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vkev->tl_semaphore = { device->device.createSemaphore(ci), 0 }; return new ggml_backend_event { /* .device = */ dev, @@ -15693,8 +15728,16 @@ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backe vk_event *vkev = (vk_event *)event->context; - device->device.destroyFence(vkev->fence); - device->device.destroyEvent(vkev->event); + device->device.destroySemaphore(vkev->tl_semaphore.s); + for (auto& event : vkev->events_free) { + device->device.destroyEvent(event); + } + for (auto& event : vkev->events_submitted) { + device->device.destroyEvent(event); + } + if (vkev->has_event) { + device->device.destroyEvent(vkev->event); + } delete vkev; delete event; } @@ -15705,10 +15748,29 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm auto device = ggml_vk_get_device(ctx->device); vk_event *vkev = (vk_event *)event->context; - VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize"); - // Finished using current command buffer so we flag for reuse - if (vkev->cmd_buffer) { - vkev->cmd_buffer->in_use = false; + // Only do something if the event has actually been used + if (vkev->has_event) { + vk::Semaphore sem = vkev->tl_semaphore.s; + uint64_t val = vkev->tl_semaphore.value; + vk::SemaphoreWaitInfo swi{vk::SemaphoreWaitFlags{}, sem, val}; + VK_CHECK(device->device.waitSemaphores(swi, UINT64_MAX), "event_synchronize"); + + // Reset and move submitted events + for (auto& event : vkev->events_submitted) { + device->device.resetEvent(event); + } + vkev->events_free.insert(vkev->events_free.end(), vkev->events_submitted.begin(), vkev->events_submitted.end()); + vkev->events_submitted.clear(); + + // Finished using current command buffer so we flag for reuse + if (vkev->cmd_buffer) { + // Only flag for reuse if it hasn't been reused already + if (vkev->cmd_buffer_use_counter == vkev->cmd_buffer->use_counter) { + vkev->cmd_buffer->in_use = false; + vkev->cmd_buffer->buf.reset(); + } + vkev->cmd_buffer = nullptr; + } } } From c890a9d9b4f6ad9c9a75387a0b0d3c973ad7f4ca Mon Sep 17 00:00:00 2001 From: Taimur Ahmad Date: Tue, 17 Mar 2026 19:03:40 +0500 Subject: [PATCH 010/249] ggml-cpu: fix RVV checks in quants and repacking (llama/20682) * ggml-cpu: refactor quants.c; add rvv check * ggml-cpu: refactor; disable generic fallback --- ggml/src/ggml-cpu/arch/riscv/quants.c | 40 +++++++++++++++++-------- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 40 ++++--------------------- ggml/src/ggml-cpu/repack.cpp | 3 ++ 3 files changed, 35 insertions(+), 48 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index 826055dd9a4..d7e9ba46348 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -115,10 +115,10 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { assert(k % QK_K == 0); - block_q8_K * y_blocks = (block_q8_K *)y; size_t nb = k / QK_K; #if defined(__riscv_v_intrinsic) + block_q8_K * y_blocks = (block_q8_K *)y; const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8(); for (size_t i = 0; i < nb; i++) { @@ -2052,6 +2052,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2147,6 +2148,7 @@ static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2163,6 +2165,7 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2269,6 +2272,7 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2285,6 +2289,7 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +#if defined __riscv_v_intrinsic static const uint8_t sign_gather_indices_arr[64] = { 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 @@ -2488,6 +2493,7 @@ static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t } *s = 0.125f * sumf; } +#endif void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2507,7 +2513,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined(__riscv_v) +#if defined(__riscv_v_intrinsic) static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -2542,7 +2548,6 @@ static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, }; -#endif static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); @@ -2618,6 +2623,7 @@ static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ } *s = 0.125f * sumf; } +#endif void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2634,6 +2640,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2818,6 +2825,7 @@ static void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size } *s = 0.125f * sumf; } +#endif void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2830,10 +2838,11 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const break; } #else - ggml_vec_dot_iq2_xxs_q8_K(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); @@ -2928,6 +2937,7 @@ static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t } *s = sumf; } +#endif void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2944,6 +2954,7 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -3036,6 +3047,7 @@ static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size } *s = 0.25f * sumf; } +#endif void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3052,6 +3064,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -3161,6 +3174,7 @@ static void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_ *s = sumf; } +#endif void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3177,6 +3191,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -3190,7 +3205,6 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ const int nb = n / QK_K; -#if defined __riscv_v_intrinsic const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); float sumf = 0; int acc[4]; @@ -3252,14 +3266,8 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ } *s = sumf; - -#else - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif } +#endif void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3276,6 +3284,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -3381,6 +3390,7 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3397,6 +3407,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -3467,6 +3478,7 @@ static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3483,6 +3495,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -3592,6 +3605,7 @@ static void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3604,6 +3618,6 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo break; } #else - return ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index cd5807879ea..c37488cae54 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -107,8 +107,7 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR } #else UNUSED(nb); - UNUSED(y); - ggml_quantize_mat_q8_0_4x4_generic(x, vy, k); + ggml_quantize_mat_q8_0_4x8_generic(x, vy, k); #endif } @@ -203,6 +202,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -222,7 +222,6 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); @@ -256,9 +255,6 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } - return; -#endif - ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -280,7 +276,6 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const block_q8_K * a_ptr = (const block_q8_K *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -392,9 +387,6 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } - return; -#endif - ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -416,7 +408,6 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -451,9 +442,6 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } - return; -#endif - ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -476,7 +464,6 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(blocklen); UNUSED(bs); -#if defined __riscv_v_intrinsic const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); @@ -505,9 +492,6 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } - return; -#endif - ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -679,9 +663,9 @@ void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } // End K-Block __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl); - } } +#endif void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -909,6 +893,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -929,7 +914,6 @@ void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -994,9 +978,6 @@ void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } - return; -#endif - ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1019,7 +1000,6 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic for (int y = 0; y < nr / 4; y++) { const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -1267,9 +1247,6 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } - return; -#endif - ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1292,7 +1269,6 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); @@ -1355,9 +1331,6 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } - return; -#endif - ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1380,7 +1353,6 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -1429,9 +1401,6 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } - return; -#endif - ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1731,3 +1700,4 @@ void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } } } +#endif diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 6b76ab3bfb1..f18758f16bb 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1365,6 +1365,7 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, } } +// Only enable these for RISC-V. #if defined __riscv_zvfh void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -1568,6 +1569,7 @@ void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, assert(nc % 16 == 0); UNUSED(bs); + UNUSED(nr); const int nb = n / QK_K; const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; @@ -2381,6 +2383,7 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int n, } } +// Only enable these for RISC-V. #if defined __riscv_zvfh void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; From 906aef3da84894d4b59e4f94d5fe69bc6fd0f01b Mon Sep 17 00:00:00 2001 From: Kevin Hannon Date: Tue, 17 Mar 2026 13:16:49 -0400 Subject: [PATCH 011/249] ggml-blas: set mkl threads from thread context (llama/20602) * ggml blas: set mkl threads from thread context * add code to run blas locally --- ggml/src/ggml-blas/ggml-blas.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 5de64b816fc..e7a1763b54d 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -121,6 +121,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg bli_thread_set_num_threads(ctx->n_threads); #elif defined(GGML_BLAS_USE_NVPL) nvpl_blas_set_num_threads(ctx->n_threads); +#elif defined(GGML_BLAS_USE_MKL) + mkl_set_num_threads(ctx->n_threads); #endif for (int64_t i13 = 0; i13 < ne13; i13++) { From 16ca5e6fb130cd68d1d499db8b59361e3aba0db6 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 17 Mar 2026 21:51:43 +0100 Subject: [PATCH 012/249] vulkan: disable mmvq on Intel Windows driver (llama/20672) * vulkan: disable mmvq on Intel Windows driver * improve comment --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3d8ce10676e..3e36435d166 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7646,20 +7646,14 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return true; } case VK_VENDOR_ID_INTEL: - if (k < 2048) { + if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) { + // Intel Windows proprietary driver MMVQ performance is worse than fp16, see + // https://github.com/ggml-org/llama.cpp/issues/17628 return false; } - if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) { - // Intel Windows proprietary driver tuning - switch (src0_type) { - case GGML_TYPE_MXFP4: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - return false; - default: - return true; - } + if (k < 2048) { + return false; } switch (src0_type) { From e222814fc4bef846054b071b48bb54e89fcc00c5 Mon Sep 17 00:00:00 2001 From: Krishna Sridhar <99914379+srikris-sridhar@users.noreply.github.com> Date: Tue, 17 Mar 2026 15:34:36 -0700 Subject: [PATCH 013/249] hexagon: add neg, exp, sigmoid, softplus ops, cont, repeat ops (llama/20701) Add element-wise unary ops needed by Qwen 3.5's DeltaNet linear attention layers. These ops follow the existing unary-ops pattern with VTCM DMA double-buffering. - neg: negate via scale by -1.0 - exp: uses existing hvx_exp_f32 HVX intrinsics - sigmoid: uses existing hvx_sigmoid_f32_aa HVX intrinsics - softplus: log(1 + exp(x)) scalar fallback - CONT reuses the existing CPY infrastructure since making a tensor contiguous is equivalent to a same-type copy. - REPEAT implements tiled memory copy with multi-threaded execution via the worker pool, supporting f32 and f16 types. The kernel parallelizes across output rows and uses memcpy for each tile. Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 152 ++++++++++++++++++++--- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/htp-msg.h | 5 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/hvx-base.h | 2 + ggml/src/ggml-hexagon/htp/hvx-exp.h | 17 +-- ggml/src/ggml-hexagon/htp/hvx-sigmoid.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 45 +++++++ ggml/src/ggml-hexagon/htp/repeat-ops.c | 148 ++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/softmax-ops.c | 2 +- ggml/src/ggml-hexagon/htp/unary-ops.c | 95 ++++++++++++++ 11 files changed, 441 insertions(+), 28 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/repeat-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 19917cb1140..4b8a16c3635 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2362,6 +2362,27 @@ static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, return n_bufs; } +static inline size_t init_cont_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + // CONT is just a contiguous copy — reuse CPY op + req->op = HTP_OP_CPY; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + +static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_REPEAT; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { req->op = HTP_OP_GET_ROWS; @@ -2449,12 +2470,33 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf break; case GGML_OP_UNARY: - if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { + switch (ggml_get_unary_op(t)) { + case GGML_UNARY_OP_SILU: req->op = HTP_OP_UNARY_SILU; supported = true; - } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) { + break; + case GGML_UNARY_OP_GELU: req->op = HTP_OP_UNARY_GELU; supported = true; + break; + case GGML_UNARY_OP_SIGMOID: + req->op = HTP_OP_UNARY_SIGMOID; + supported = true; + break; + case GGML_UNARY_OP_NEG: + req->op = HTP_OP_UNARY_NEG; + supported = true; + break; + case GGML_UNARY_OP_EXP: + req->op = HTP_OP_UNARY_EXP; + supported = true; + break; + case GGML_UNARY_OP_SOFTPLUS: + req->op = HTP_OP_UNARY_SOFTPLUS; + supported = true; + break; + default: + break; } break; @@ -2640,16 +2682,28 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; case GGML_OP_UNARY: - if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) || - (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) { - ggml_hexagon_dispatch_op(sess, node, flags); + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: + break; } break; case GGML_OP_GLU: - if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) || - (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) || - (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) { - ggml_hexagon_dispatch_op(sess, node, flags); + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: + break; } break; case GGML_OP_SOFT_MAX: @@ -2676,6 +2730,14 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_CONT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + + case GGML_OP_REPEAT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + case GGML_OP_ARGSORT: ggml_hexagon_dispatch_op(sess, node, flags); break; @@ -3006,6 +3068,39 @@ static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess, return true; } +static bool ggml_hexagon_supported_cont(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + GGML_UNUSED(sess); + const struct ggml_tensor * src0 = op->src[0]; + + // CONT is same-type only, supports f32 and f16 + if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false; + + return true; +} + +static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + GGML_UNUSED(sess); + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + // Support f32 and f16 + if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false; + + // src and dst must be the same type + if (src0->type != dst->type) return false; + + // dst dims must be multiples of src dims + if (dst->ne[0] % src0->ne[0] != 0) return false; + if (dst->ne[1] % src0->ne[1] != 0) return false; + if (dst->ne[2] % src0->ne[2] != 0) return false; + if (dst->ne[3] % src0->ne[3] != 0) return false; + + // require contiguous tensors (no transposition) + if (ggml_is_transposed(src0) || ggml_is_transposed(dst)) return false; + + return true; +} + static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast(dev->context); @@ -3063,21 +3158,32 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; case GGML_OP_UNARY: - { - const auto unary_op = ggml_get_unary_op(op); - if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) { + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_SOFTPLUS: + supp = ggml_hexagon_supported_unary(sess, op); + break; + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: supp = ggml_hexagon_supported_activations(sess, op); - } - break; + break; + default: + break; } + break; case GGML_OP_GLU: - { - const auto glu_op = ggml_get_glu_op(op); - if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) { + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU: supp = ggml_hexagon_supported_activations(sess, op); - } - break; + break; + default: + break; } + break; case GGML_OP_ROPE: supp = ggml_hexagon_supported_rope(sess, op); break; @@ -3098,6 +3204,14 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cpy(sess, op); break; + case GGML_OP_CONT: + supp = ggml_hexagon_supported_cont(sess, op); + break; + + case GGML_OP_REPEAT: + supp = ggml_hexagon_supported_repeat(sess, op); + break; + case GGML_OP_ARGSORT: supp = ggml_hexagon_supported_argsort(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 02d07a503d5..a490a2ce9a1 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -30,6 +30,7 @@ add_library(${HTP_LIB} SHARED set-rows-ops.c get-rows-ops.c cpy-ops.c + repeat-ops.c argsort-ops.c ssm-conv.c ) diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 52dcc36d8f7..56bc5b622c5 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -53,6 +53,10 @@ enum htp_op { HTP_OP_RMS_NORM, HTP_OP_UNARY_SILU, HTP_OP_UNARY_GELU, + HTP_OP_UNARY_SIGMOID, + HTP_OP_UNARY_EXP, + HTP_OP_UNARY_NEG, + HTP_OP_UNARY_SOFTPLUS, HTP_OP_GLU_SWIGLU, HTP_OP_GLU_SWIGLU_OAI, HTP_OP_GLU_GEGLU, @@ -69,6 +73,7 @@ enum htp_op { HTP_OP_SQRT, HTP_OP_SUM_ROWS, HTP_OP_SSM_CONV, + HTP_OP_REPEAT, INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 2ef20936f1b..f643fdc340d 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -57,6 +57,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx); int op_set_rows(struct htp_ops_context * octx); int op_get_rows(struct htp_ops_context * octx); int op_cpy(struct htp_ops_context * octx); +int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index 578ca288fb6..3e6a8579b1f 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -3,6 +3,8 @@ #include #include +#include +#include #include "hex-utils.h" #include "hvx-types.h" diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.h b/ggml/src/ggml-hexagon/htp/hvx-exp.h index 44dfe232a3d..84e4836dc92 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-exp.h +++ b/ggml/src/ggml-hexagon/htp/hvx-exp.h @@ -3,6 +3,7 @@ #include #include +#include #include "hvx-base.h" #include "hvx-floor.h" @@ -16,8 +17,8 @@ #define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 #define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 #define EXP_ONE (0x3f800000) // 1.0 -#define EXP_RANGE_R (0x41a00000) // 20.0 -#define EXP_RANGE_L (0xc1a00000) // -20.0 +#define EXP_RANGE_R (0x42B16666) // 88.7 +#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN)) static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { HVX_Vector z_qf32_v; @@ -47,12 +48,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { HVX_Vector temp_v = in_vec; - // Clamp inputs to (-20.0, 20.0) + // Clamp inputs to (-88.0, 88.0) to avoid overflow/underflow HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R)); HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec); in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v); - in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v); + in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec); epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec); epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v); @@ -69,12 +70,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { // normalize before every QFloat's vmpy x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v); + x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); + // z = x * x; z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v); z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v); - x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); - // y = E4 + E5 * x; E_const = Q6_V_vsplat_R(EXP_COEFF_5); y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v); @@ -145,7 +146,7 @@ static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max return Q6_V_vmux_QVV(pred0, inf, out); } -static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) { +static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems, bool negate) { int left_over = num_elems & (VLEN_FP32 - 1); int num_elems_whole = num_elems - left_over; @@ -162,7 +163,7 @@ static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict HVX_Vector vec_out = Q6_V_vzero(); static const float kInf = INFINITY; - static const float kMaxExp = 88.02f; // log(INF) + static const float kMaxExp = 88.7f; const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); const HVX_Vector inf = hvx_vec_splat_f32(kInf); diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h index 095193277ea..37f3e7b6fae 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +++ b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h @@ -2,6 +2,7 @@ #define HVX_SIGMOID_H #include "hvx-base.h" +#include "hvx-inverse.h" #define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 #define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 3f99dbb32c4..2a3f9e562b7 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -516,6 +516,39 @@ static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_repeat_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = op_repeat(&octx); + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { struct dspqueue_buffer rsp_bufs[1]; @@ -1090,6 +1123,10 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { case HTP_OP_SQR: case HTP_OP_SQRT: + case HTP_OP_UNARY_NEG: + case HTP_OP_UNARY_EXP: + case HTP_OP_UNARY_SIGMOID: + case HTP_OP_UNARY_SOFTPLUS: if (n_bufs != 2) { FARF(ERROR, "Bad unary-req buffer list"); continue; @@ -1175,6 +1212,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_cpy_req(ctx, &req, bufs); break; + case HTP_OP_REPEAT: + if (n_bufs != 2) { + FARF(ERROR, "Bad repeat-req buffer list"); + continue; + } + proc_repeat_req(ctx, &req, bufs); + break; + case HTP_OP_ARGSORT: if (n_bufs != 2) { FARF(ERROR, "Bad argsort-req buffer list"); diff --git a/ggml/src/ggml-hexagon/htp/repeat-ops.c b/ggml/src/ggml-hexagon/htp/repeat-ops.c new file mode 100644 index 00000000000..5db06c920e2 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/repeat-ops.c @@ -0,0 +1,148 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +struct htp_repeat_context { + struct htp_ops_context * octx; + + uint32_t nr0; + uint32_t nr1; + uint32_t nr2; + uint32_t nr3; + + uint32_t nrows_per_thread; + uint32_t total_dst_rows; // ne1 * ne2 * ne3 + + size_t type_size; +}; + +static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data; + struct htp_ops_context * octx = rctx->octx; + const struct htp_tensor * src = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + + const uint32_t ne00 = src->ne[0]; + const uint32_t ne01 = src->ne[1]; + const uint32_t ne02 = src->ne[2]; + const uint32_t ne03 = src->ne[3]; + + const uint32_t nb00 = src->nb[0]; + const uint32_t nb01 = src->nb[1]; + const uint32_t nb02 = src->nb[2]; + const uint32_t nb03 = src->nb[3]; + + const uint32_t ne0 = dst->ne[0]; + const uint32_t ne1 = dst->ne[1]; + const uint32_t ne2 = dst->ne[2]; + const uint32_t ne3 = dst->ne[3]; + + const uint32_t nb0 = dst->nb[0]; + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + const uint32_t nr0 = rctx->nr0; + const uint32_t nr1 = rctx->nr1; + const uint32_t nr2 = rctx->nr2; + const uint32_t nr3 = rctx->nr3; + + const size_t row_bytes = ne00 * rctx->type_size; + + const uint32_t row_start = rctx->nrows_per_thread * ith; + const uint32_t row_end = MIN(row_start + rctx->nrows_per_thread, rctx->total_dst_rows); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + // Decompose flat dst row index into (i1, i2, i3) + const uint32_t i1 = dst_row % ne1; + const uint32_t i2 = (dst_row / ne1) % ne2; + const uint32_t i3 = dst_row / (ne1 * ne2); + + // Map to source indices (tiling) + const uint32_t k1 = i1 % ne01; + const uint32_t k2 = i2 % ne02; + const uint32_t k3 = i3 % ne03; + + const uint8_t * src_row = (const uint8_t *) src->data + k1 * nb01 + k2 * nb02 + k3 * nb03; + uint8_t * dst_base = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + // Tile along dimension 0 + for (uint32_t i0 = 0; i0 < nr0; i0++) { + uint8_t * dst_ptr = dst_base + i0 * ne00 * nb0; + memcpy(dst_ptr, src_row, row_bytes); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "repeat %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_repeat(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = &octx->src0; + struct htp_tensor * dst = &octx->dst; + + // Validate that dst dims are multiples of src dims + if (dst->ne[0] % src0->ne[0] != 0 || + dst->ne[1] % src0->ne[1] != 0 || + dst->ne[2] % src0->ne[2] != 0 || + dst->ne[3] % src0->ne[3] != 0) { + FARF(ERROR, "repeat: dst dims must be multiples of src dims\n"); + return HTP_STATUS_INVAL_PARAMS; + } + + size_t type_size; + switch (src0->type) { + case HTP_TYPE_F32: type_size = 4; break; + case HTP_TYPE_F16: type_size = 2; break; + default: + FARF(ERROR, "repeat: unsupported type %u\n", src0->type); + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows); + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + struct htp_repeat_context rctx = { + .octx = octx, + .nr0 = dst->ne[0] / src0->ne[0], + .nr1 = dst->ne[1] / src0->ne[1], + .nr2 = dst->ne[2] / src0->ne[2], + .nr3 = dst->ne[3] / src0->ne[3], + .nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads, + .total_dst_rows = total_dst_rows, + .type_size = type_size, + }; + + FARF(HIGH, "repeat: (%ux%ux%ux%u) -> (%ux%ux%ux%u) nr=(%u,%u,%u,%u)\n", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + rctx.nr0, rctx.nr1, rctx.nr2, rctx.nr3); + + worker_pool_run_func(octx->ctx->worker_pool, repeat_job_per_thread, &rctx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 8dae7f1ed55..d6356b9506f 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -195,7 +195,7 @@ static float hvx_softmax_f32(const uint8_t * restrict src, const float max) { hvx_sub_scalar_f32(spad, src, max, num_elems); - hvx_exp_f32(spad, dst, num_elems, false); + hvx_exp_f32(dst, spad, num_elems, false); float sum = hvx_reduce_sum_f32(dst, num_elems); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 5bbd5040d3d..3d0928d4dce 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -9,6 +9,8 @@ #include #include "hex-dma.h" +#include "hvx-exp.h" +#include "hvx-sigmoid.h" #include "hvx-utils.h" #define GGML_COMMON_DECL_C @@ -166,6 +168,75 @@ static void sqrt_f32(const float * restrict src, } } +static void neg_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_scale_f32_aa(dst_local, src_local, row_elems, -1.0f); + } +} + +static void exp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_exp_f32(dst_local, src_local, row_elems, false); + } +} + +static void sigmoid_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_sigmoid_f32_aa(dst_local, src_local, row_elems); + } +} + +static void softplus_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + // softplus(x) = log(1 + exp(x)) + // Match CPU reference: ggml_compute_softplus_f32() in ggml-impl.h + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size)); + float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size)); + + for (uint32_t i = 0; i < row_elems; i++) { + float x = src_f[i]; + // For x > 20: softplus(x) ≈ x (avoids exp overflow) + dst_f[i] = (x > 20.0f) ? x : logf(1.0f + expf(x)); + } + } +} + static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; struct htp_ops_context * octx = uctx->octx; @@ -247,6 +318,18 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_SQRT: sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_UNARY_NEG: + neg_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_EXP: + exp_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_SIGMOID: + sigmoid_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_SOFTPLUS: + softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; default: break; } @@ -295,6 +378,18 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_SQRT: op_type = "sqrt-f32"; break; + case HTP_OP_UNARY_NEG: + op_type = "neg-f32"; + break; + case HTP_OP_UNARY_EXP: + op_type = "exp-f32"; + break; + case HTP_OP_UNARY_SIGMOID: + op_type = "sigmoid-f32"; + break; + case HTP_OP_UNARY_SOFTPLUS: + op_type = "softplus-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); From 61c7cd024dd371952f3dae27266eaf7bf82f2f04 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 18 Mar 2026 09:53:13 +0100 Subject: [PATCH 014/249] HIP : ignore return of hipMemAdvise [no ci] (llama/20696) --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5a0be4a472a..a31e843e153 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -126,7 +126,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) if (err == hipSuccess) { // hipMemAdviseSetCoarseGrain is an optional performance hint; // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs). - cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device); + (void)cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device); (void)hipGetLastError(); // clear any error } From 14caedfa18bfcf75888661221117db591897b40b Mon Sep 17 00:00:00 2001 From: Shaw Nguyen <49144872+mrshaw01@users.noreply.github.com> Date: Wed, 18 Mar 2026 23:45:06 +0700 Subject: [PATCH 015/249] ggml-cpu/x86: fix unused changemask warning in repack (llama/20692) --- ggml/src/ggml-cpu/arch/x86/repack.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index 33c6cb65098..af1cebad131 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -531,7 +531,6 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t UNUSED(bs); - __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); // Permute mask used for easier vector processing at later stages @@ -580,6 +579,7 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t if constexpr ( std::is_same_v || std::is_same_v) { + const __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); } else if constexpr (std::is_same_v) { // Load 8 E8M0 exponents and convert to float via LUT From d6a0f0d075a2732e30031408e843fbbb712a860f Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 18 Mar 2026 10:23:47 -0700 Subject: [PATCH 016/249] Move to no timeout for WaitAny in graph submission to avoid deadlocks in some cases on llvm-pipe backends (llama/20618) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 60 +++++++++++----------------- 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 128b7dc3de8..3976a171d16 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -509,50 +509,39 @@ static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, std::vector & subs, bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. if (subs.empty()) { return; } - while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { - auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX); - if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { + + bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; + while (blocking_wait) { + auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, 0); + if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { #ifdef GGML_WEBGPU_GPU_PROFILE ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); #endif subs.erase(subs.begin()); } + blocking_wait = (block && !subs.empty()) || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; } if (subs.empty()) { return; } - if (block) { - for (auto & sub : subs) { - while (!sub.submit_done.completed) { - auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX); - ggml_backend_webgpu_handle_wait_status(waitStatus); - } -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true); -#endif - } - subs.clear(); - } else { - // Poll each submit future once and remove completed submissions. - for (auto sub = subs.begin(); sub != subs.end();) { - auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0); - ggml_backend_webgpu_handle_wait_status(waitStatus, true); + // Poll each submit future once and remove completed submissions. + for (auto sub = subs.begin(); sub != subs.end();) { + auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0); + bool success = ggml_backend_webgpu_handle_wait_status(waitStatus, true); #ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false); - if (sub->submit_done.completed && sub->profile_futures.empty()) { + ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false); + if (success && sub->profile_futures.empty()) { #else - if (sub->submit_done.completed) { + if (success) { #endif - sub = subs.erase(sub); - } else { - ++sub; - } + sub = subs.erase(sub); + } else { + ++sub; } } } @@ -2961,17 +2950,16 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { - /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, - /* .alloc_buffer = */ - ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */ - ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */ - ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */ - ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false + /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size, + /* .is_host = */ NULL, // defaults to false }, /* .device = */ - dev, - /* .context = */ - NULL + dev, + /* .context = */ NULL }; return &ggml_backend_webgpu_buffer_type; From dfba84cb470ec2c4d750936b048460648aea7db6 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Thu, 19 Mar 2026 11:02:42 +0800 Subject: [PATCH 017/249] CANN: support flash attention for head dim not multiple of 16, fix ALiBi slope offset (llama/20031) - Allow FLASH_ATTN_EXT when head dimension D is not a multiple of 16 by padding Q/K/V to D_padded = GGML_PAD(D, 16), running FusedInferAttentionScoreV2, then slicing the output back to D (ggml-cann.cpp + aclnn_ops.cpp). - Fix aclnn_get_slope second-part offset: use ggml_type_size(dtype) instead of sizeof(float) so ALiBi slopes are correct when dtype is F16 (e.g. GQA with 48 heads); fixes buffer overflow and large numerical errors in those cases. --- ggml/src/ggml-cann/aclnn_ops.cpp | 78 ++++++++++++++++++++++++++++---- ggml/src/ggml-cann/ggml-cann.cpp | 4 -- 2 files changed, 70 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index fc7c3e3b724..4b7aab1e72d 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1544,8 +1544,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, end = 2 * ((n_head - 1) - n_head_log2) + 1; step = 2; count = n_head - n_head_log2; - aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step, - dtype); + aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * ggml_type_size(dtype), m1, count, start, end + 1, + step, dtype); } } @@ -3599,6 +3599,44 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS); acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS); + // Step 2.5: Pad Q, K, V along head dimension if D is not a multiple of 16 + // (required by FusedInferAttentionScoreV2) + const int64_t D = src0->ne[0]; + const int64_t D_padded = GGML_PAD(D, 16); + const bool needs_padding = (D != D_padded); + + ggml_cann_pool_alloc q_pad_allocator(ctx.pool()); + ggml_cann_pool_alloc k_pad_allocator(ctx.pool()); + ggml_cann_pool_alloc v_pad_allocator(ctx.pool()); + + if (needs_padding) { + int64_t paddings[] = { 0, D_padded - D, 0, 0, 0, 0, 0, 0 }; + + auto pad_fa_tensor = [&](acl_tensor_ptr & tensor, const int64_t * bsnd_ne, + ggml_cann_pool_alloc & allocator) { + int64_t pad_ne[GGML_MAX_DIMS] = { D_padded, bsnd_ne[1], bsnd_ne[2], bsnd_ne[3] }; + size_t pad_nb[GGML_MAX_DIMS]; + pad_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + pad_nb[i] = pad_nb[i - 1] * pad_ne[i - 1]; + } + int64_t nelements = pad_ne[0] * pad_ne[1] * pad_ne[2] * pad_ne[3]; + void * buffer = allocator.alloc(nelements * faElemSize); + acl_tensor_ptr padded = + ggml_cann_create_tensor(buffer, faDataType, faElemSize, pad_ne, pad_nb, GGML_MAX_DIMS); + aclnn_pad(ctx, tensor.get(), padded.get(), paddings); + tensor = std::move(padded); + }; + + pad_fa_tensor(acl_q_tensor, src0_bsnd_ne, q_pad_allocator); + pad_fa_tensor(acl_k_tensor, src1_bsnd_ne, k_pad_allocator); + pad_fa_tensor(acl_v_tensor, src2_bsnd_ne, v_pad_allocator); + + src0_bsnd_ne[0] = D_padded; + src1_bsnd_ne[0] = D_padded; + src2_bsnd_ne[0] = D_padded; + } + // Step 3: create the PSEShift tensor if needed // this tensor is considered as mask (f16) in the llama.cpp acl_tensor_ptr bcast_pse_tensor; @@ -3688,17 +3726,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); acl_tensor_ptr fa_dst_tensor; - acl_tensor_ptr acl_dst_tensor; ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); - if (dst->type == GGML_TYPE_F32) { - void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); - + if (dst->type == GGML_TYPE_F32 || needs_padding) { int64_t * out_f16_ne = src0_bsnd_ne; size_t out_f16_nb[GGML_MAX_DIMS]; out_f16_nb[0] = faElemSize; for (int i = 1; i < GGML_MAX_DIMS; ++i) { out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; } + int64_t out_nelements = out_f16_ne[0] * out_f16_ne[1] * out_f16_ne[2] * out_f16_ne[3]; + void * out_f16_buffer = out_f16_allocator.alloc(out_nelements * faElemSize); fa_dst_tensor = ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS); @@ -3730,8 +3767,33 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst nullptr // softmaxLse ); - if (dst->type == GGML_TYPE_F32) { - // Step 6: post-processing, permute and cast to f32 + // Step 6: post-processing — slice padded output and/or cast to f32 + if (needs_padding) { + ggml_cann_pool_alloc sliced_f16_allocator(ctx.pool()); + + if (dst->type == GGML_TYPE_F32) { + int64_t sliced_ne[GGML_MAX_DIMS] = { D, src0_bsnd_ne[1], src0_bsnd_ne[2], src0_bsnd_ne[3] }; + size_t sliced_nb[GGML_MAX_DIMS]; + sliced_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + sliced_nb[i] = sliced_nb[i - 1] * sliced_ne[i - 1]; + } + int64_t sliced_nelements = sliced_ne[0] * sliced_ne[1] * sliced_ne[2] * sliced_ne[3]; + void * sliced_buffer = sliced_f16_allocator.alloc(sliced_nelements * faElemSize); + acl_tensor_ptr sliced_f16_tensor = ggml_cann_create_tensor(sliced_buffer, faDataType, faElemSize, + sliced_ne, sliced_nb, GGML_MAX_DIMS); + + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(), + (int64_t) -1, (int64_t) 0, D, (int64_t) 1, sliced_f16_tensor.get()); + + acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); + aclnn_cast(ctx, sliced_f16_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type)); + } else { + acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(), + (int64_t) -1, (int64_t) 0, D, (int64_t) 1, acl_dst_tensor.get()); + } + } else if (dst->type == GGML_TYPE_F32) { acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type)); } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 3f3de9f0bcb..a682746bb42 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2503,10 +2503,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten // different head sizes of K and V are not supported yet return false; } - if (op->src[0]->ne[0] % 16 != 0) { - // TODO: padding to support - return false; - } float logitSoftcap = 0.0f; memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float)); if (logitSoftcap != 0.0f) { From 12015a2174ad014cffdafddb0158875c3de8aed5 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Thu, 19 Mar 2026 13:08:35 +0900 Subject: [PATCH 018/249] ggml-webgpu: Add supports for `DIAG` and `TRI` (llama/20664) * Add supports for DIAG and TRI. * Remove extra ttype and add a comment for TRI op. --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 37 ++++++++++++++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 8 ++++ ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 35 ++++++++++++++---- 3 files changed, 68 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3d7e59fddf3..ad665e4de93 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -244,13 +244,15 @@ struct ggml_webgpu_binary_pipeline_key_hash { /** Unary **/ struct ggml_webgpu_unary_pipeline_key { - int type; - int op; - bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella - bool inplace; + int type; + int op; + bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella + bool inplace; + ggml_tri_type ttype; // only used for GGML_OP_TRI bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { - return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace; + return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace && + ttype == other.ttype; } }; @@ -261,6 +263,7 @@ struct ggml_webgpu_unary_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.op); ggml_webgpu_hash_combine(seed, key.is_unary); ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.ttype); return seed; } }; @@ -1058,6 +1061,7 @@ class ggml_webgpu_shader_lib { .op = op, .is_unary = is_unary, .inplace = context.inplace, + .ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0), }; auto it = unary_pipelines.find(key); @@ -1088,6 +1092,29 @@ class ggml_webgpu_shader_lib { variant += "_inplace"; } + if (op == GGML_OP_TRI) { + switch (key.ttype) { + case GGML_TRI_TYPE_LOWER: + defines.push_back("TRI_TYPE_LOWER"); + variant += "_tri_type_lower"; + break; + case GGML_TRI_TYPE_LOWER_DIAG: + defines.push_back("TRI_TYPE_LOWER_DIAG"); + variant += "_tri_type_lower_diag"; + break; + case GGML_TRI_TYPE_UPPER: + defines.push_back("TRI_TYPE_UPPER"); + variant += "_tri_type_upper"; + break; + case GGML_TRI_TYPE_UPPER_DIAG: + defines.push_back("TRI_TYPE_UPPER_DIAG"); + variant += "_tri_upper_diag"; + break; + default: + GGML_ABORT("Unsupported ggml_tri_type for unary shader"); + } + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_unary, defines); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3976a171d16..4b0eeac0f42 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2209,6 +2209,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: + case GGML_OP_DIAG: + case GGML_OP_TRI: return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_PAD: return ggml_webgpu_pad(ctx, src0, node); @@ -3201,6 +3203,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_COS: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; + case GGML_OP_DIAG: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_TRI: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; case GGML_OP_PAD: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index feaf6d0ac29..21beb9bb94d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -5,7 +5,6 @@ enable f16; #define TYPE f32 #endif - @group(0) @binding(0) var src: array; @@ -57,12 +56,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { return; } var i = gid.x; - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; + let ne2 = params.ne2; +#ifdef DIAG + let ne1 = params.ne0; +#else + let ne1 = params.ne1; +#endif + let ne0 = params.ne0; + + let i3 = i / (ne2 * ne1 * ne0); + i = i % (ne2 * ne1 * ne0); + let i2 = i / (ne1 * ne0); + i = i % (ne1 * ne0); + let i1 = i / ne0; + let i0 = i % ne0; let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + i2 * params.stride_src2 + i3 * params.stride_src3; @@ -184,6 +191,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res_f32 = cos(f32(src[params.offset_src + src_idx])); let res = TYPE(res_f32); #endif +#ifdef DIAG + let res = select(0.0, src[params.offset_src + i0 + i2 * params.stride_src2 + i3 * params.stride_src3], i0 == i1); +#endif +#ifdef TRI +#ifdef TRI_TYPE_LOWER + let res = select(0.0, src[params.offset_src + src_idx], i0 < i1); +#elif TRI_TYPE_LOWER_DIAG + let res = select(0.0, src[params.offset_src + src_idx], i0 <= i1); +#elif TRI_TYPE_UPPER + let res = select(0.0, src[params.offset_src + src_idx], i0 > i1); +#elif TRI_TYPE_UPPER_DIAG + let res = select(0.0, src[params.offset_src + src_idx], i0 >= i1); +#endif +#endif #ifdef INPLACE src[params.offset_src + src_idx] = res; From 3d004fbf0af918d2bcca0b43b63371dc9666446a Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Sat, 28 Mar 2026 11:47:59 +0200 Subject: [PATCH 019/249] ggml-webgpu: Update the `RMS_NORM` preprocessor and add `L2_NORM` (llama/20665) * Update the preprocessor of RMS_NORM and add L2_NORM. * Fix the name of rms_norm to row_norm. --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 60 ++++++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 30 +++--- .../ggml-webgpu/wgsl-shaders/row_norm.wgsl | 97 +++++++++++++++++++ 3 files changed, 171 insertions(+), 16 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index ad665e4de93..9d16abf20d7 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -151,6 +151,26 @@ struct ggml_webgpu_get_rows_pipeline_key_hash { } }; +/** Row Norm **/ + +struct ggml_webgpu_row_norm_pipeline_key { + ggml_op op; + bool inplace; + + bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const { + return op == other.op && inplace == other.inplace; + } +}; + +struct ggml_webgpu_row_norm_pipeline_key_hash { + size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** Pad **/ struct ggml_webgpu_pad_pipeline_key { bool circular; @@ -438,6 +458,8 @@ class ggml_webgpu_shader_lib { std::unordered_map argsort_pipelines; // key is order std::unordered_map argsort_merge_pipelines; // key is order std::unordered_map cumsum_pipelines; // key is fixed, no variants yet + std::unordered_map + row_norm_pipelines; // op/inplace std::unordered_map get_rows_pipelines; // src_type, vectorized std::unordered_map @@ -482,6 +504,44 @@ class ggml_webgpu_shader_lib { return sum_rows_pipelines[1]; } + webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_row_norm_pipeline_key key = { + .op = context.dst->op, + .inplace = context.inplace, + }; + + auto it = row_norm_pipelines.find(key); + if (it != row_norm_pipelines.end()) { + return it->second; + } + std::vector defines; + std::string variant; + + switch (key.op) { + case GGML_OP_RMS_NORM: + defines.push_back("OP_RMS_NORM"); + variant = "rms_norm"; + break; + case GGML_OP_L2_NORM: + defines.push_back("OP_L2_NORM"); + variant = "l2_norm"; + break; + default: + GGML_ABORT("Unsupported op for row_norm shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_row_norm, defines); + row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + return row_norm_pipelines[key]; + } + webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) { bool vec4 = context.src0->ne[0] % 4 == 0; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4b0eeac0f42..f7973df682a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -366,7 +366,6 @@ struct webgpu_context_struct { std::map> cpy_pipelines; // src_type, dst_type - std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace std::map>> glu_pipelines; // glu_op, type, split @@ -1598,8 +1597,8 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - int inplace = ggml_webgpu_tensor_equal(src, dst); +static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + bool inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), @@ -1630,8 +1629,15 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params, - entries, ggml_nrows(src)); + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .dst = dst, + .max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE, + .inplace = inplace, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src)); } static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, @@ -2192,7 +2198,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_REPEAT: return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: - return ggml_webgpu_rms_norm(ctx, src0, node); + case GGML_OP_L2_NORM: + return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: return ggml_webgpu_rope(ctx, src0, src1, src2, node); case GGML_OP_GLU: @@ -2616,15 +2623,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } -static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - - webgpu_ctx->rms_norm_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants); - webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); -} - static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); @@ -2909,7 +2907,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); ggml_webgpu_init_rope_pipeline(webgpu_ctx); ggml_webgpu_init_glu_pipeline(webgpu_ctx); ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); @@ -3120,6 +3117,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; case GGML_OP_ROPE: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl new file mode 100644 index 00000000000..7777944941c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl @@ -0,0 +1,97 @@ +#ifdef INPLACE +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + src[dst_offset] = scale * src[src_offset]; +} + +@group(0) @binding(1) +var params: Params; +#else +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + dst[dst_offset] = scale * src[src_offset]; +} + +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; +#endif + +struct Params { + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src/dst + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + eps: f32 +}; + +@group(0) @binding(0) +var src: array; + +var scratch: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + + // one thread per row + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; + + var sum = 0.0f; + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + sum += pow(src[i_src_row + col], 2.0); + col += WG_SIZE; + } + + scratch[lid.x] = sum; + workgroupBarrier(); + var offset: u32 = WG_SIZE / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + sum = scratch[0]; + +#ifdef OP_RMS_NORM + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); +#elif OP_L2_NORM + let scale = 1.0/max(sqrt(sum), params.eps); +#endif + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + update(i_src_row + col, i_dst_row + col, scale); + col += WG_SIZE; + } +} From 2a6de29364870e524efc88fd0a470139ddf28332 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Thu, 19 Mar 2026 14:05:01 +0800 Subject: [PATCH 020/249] CANN: handle in-place ROPE on non-contiguous f32 tensors (llama/20274) RotaryPositionEmbedding on CANN fails when src and dst share the same non-contiguous buffer (inplace + view), because the operator overwrites source data before it is fully read. Add a branch that detects this case and uses contiguous temporary buffers: copy src to temp, run ROPE into another temp, then copy back to the non-contiguous dst. Fixes 20 failing ROPE tests (f32, v=1, inplace=1). Signed-off-by: noemotiovon <757486878@qq.com> --- ggml/src/ggml-cann/aclnn_ops.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 4b7aab1e72d..9b736636def 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2943,6 +2943,27 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // Rotate full tensor (no tail), using trans tensors GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(), acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get()); + } else if (src0->data == dst->data && !ggml_is_contiguous(src0)) { + // In-place on non-contiguous tensor: RotaryPositionEmbedding cannot safely + // read and write the same non-contiguous buffer. Use contiguous temporaries. + size_t contiguous_nb[GGML_MAX_DIMS]; + contiguous_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + contiguous_nb[i] = contiguous_nb[i - 1] * src0->ne[i - 1]; + } + int64_t total_elements = ggml_nelements(src0); + ggml_cann_pool_alloc inplace_src_alloc(ctx.pool(), total_elements * sizeof(float)); + ggml_cann_pool_alloc inplace_dst_alloc(ctx.pool(), total_elements * sizeof(float)); + + acl_tensor_ptr acl_src_contig = ggml_cann_create_tensor(inplace_src_alloc.get(), ACL_FLOAT, sizeof(float), + src0->ne, contiguous_nb, GGML_MAX_DIMS); + acl_tensor_ptr acl_dst_contig = ggml_cann_create_tensor(inplace_dst_alloc.get(), ACL_FLOAT, sizeof(float), + dst->ne, contiguous_nb, GGML_MAX_DIMS); + + cann_copy(ctx, acl_src.get(), acl_src_contig.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_contig.get(), acl_cos_reshape_tensor.get(), + acl_sin_reshape_tensor.get(), acl_mode, acl_dst_contig.get()); + cann_copy(ctx, acl_dst_contig.get(), acl_dst.get()); } else { // Rotate full tensor (no tail), using original tensors GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(), From fea629d00f46863f76a34a5e5f37c98cf7043524 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Thu, 19 Mar 2026 09:14:48 +0100 Subject: [PATCH 021/249] cmake : fix build warning when kleidiai is enabled (llama/20457) * cmake : fix build warning when kleidiai is enabled * remove LLAMA_ARG_THREADS from KleidiAI backend --- ggml/src/ggml-cpu/CMakeLists.txt | 36 ++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 6ca3176a2f2..7c062a62995 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -570,24 +570,34 @@ function(ggml_add_cpu_backend_variant_impl tag_name) set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1") - if (POLICY CMP0135) - cmake_policy(SET CMP0135 NEW) - endif() - - # TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+ - # Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28 - FetchContent_Declare(KleidiAI_Download + set(KLEIDIAI_FETCH_ARGS URL ${KLEIDIAI_DOWNLOAD_URL} DOWNLOAD_EXTRACT_TIMESTAMP NEW - URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}) + URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5} + ) - FetchContent_GetProperties(KleidiAI_Download - SOURCE_DIR KLEIDIAI_SRC - POPULATED KLEIDIAI_POPULATED) + if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.28") + FetchContent_Declare(KleidiAI_Download + ${KLEIDIAI_FETCH_ARGS} + EXCLUDE_FROM_ALL + ) - if (NOT KLEIDIAI_POPULATED) - FetchContent_Populate(KleidiAI_Download) + FetchContent_MakeAvailable(KleidiAI_Download) FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) + else() + FetchContent_Declare(KleidiAI_Download + ${KLEIDIAI_FETCH_ARGS} + ) + + FetchContent_GetProperties(KleidiAI_Download + SOURCE_DIR KLEIDIAI_SRC + POPULATED KLEIDIAI_POPULATED + ) + + if (NOT KLEIDIAI_POPULATED) + FetchContent_Populate(KleidiAI_Download) + FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) + endif() endif() add_compile_definitions(GGML_USE_CPU_KLEIDIAI) From 43c7c0f86c09455bf2acdb7384bcaf351d35c564 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Thu, 19 Mar 2026 10:32:04 +0000 Subject: [PATCH 022/249] vulkan: dequantize iq4_xs 4 at a time (llama/20657) --- .../ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl | 13 +++++++------ .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index ce7f2d699a2..3f494eb4d5a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -444,19 +444,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint iq = 16 * ib32 + 2 * (idx % 8); + const uint ib = idx / 64; // 4 values per idx + const uint ib32 = (idx % 64) / 8; // 0..7 + const uint iq = 4 * ib32 + (idx % 4); const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; - const uint qshift = (idx & 8) >> 1; - u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy; + const uint qshift = idx & 4; + u8vec4 qs = unpack8((uint(data_a_packed32[ib].qs[iq]) >> qshift) & 0x0F0F0F0F); const float d = float(data_a[ib].d); - const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); + const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]); buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 4b00ba3debb..abd2a9c36fa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -554,7 +554,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string load_vec_quant = "2"; if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) + else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4")) load_vec_quant = "4"; if (tname == "bf16") { From 551bb8296008094dbab71b8450f6434e6045ba35 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Thu, 19 Mar 2026 08:45:28 -0700 Subject: [PATCH 023/249] ggml webgpu: ops support for qwen3.5 (SET, TRI_SOLVE, SSM_CONV, GATED_DELTA_NET) + GET_ROWS optimization (llama/20687) * Implement l2_norm, set, tri * Add DIAG/SOLVE_TRI * Add SSM_CONV * Better get_rows and gated_delta_net to support qwen3.5 * Clean up, update ops.md * Fix binding_index type for wasm * Fix read write annotations * cleanups --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 277 ++++++++++++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 314 +++++++++++++++++- .../wgsl-shaders/gated_delta_net.wgsl | 132 ++++++++ .../ggml-webgpu/wgsl-shaders/get_rows.wgsl | 31 +- .../ggml-webgpu/wgsl-shaders/row_norm.wgsl | 5 +- ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl | 109 ++++++ .../ggml-webgpu/wgsl-shaders/solve_tri.wgsl | 121 +++++++ .../ggml-webgpu/wgsl-shaders/ssm_conv.wgsl | 65 ++++ 8 files changed, 1034 insertions(+), 20 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 9d16abf20d7..59861ac16cc 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -95,6 +95,11 @@ struct ggml_webgpu_generic_shader_decisions { uint32_t wg_size = 0; }; +struct ggml_webgpu_ssm_conv_shader_decisions { + uint32_t block_size; + uint32_t tokens_per_wg; +}; + /** Argsort **/ struct ggml_webgpu_argsort_shader_lib_context { @@ -131,6 +136,26 @@ struct ggml_webgpu_set_rows_shader_decisions { uint32_t wg_size; }; +/** Set **/ + +struct ggml_webgpu_set_pipeline_key { + ggml_type type; + bool inplace; + + bool operator==(const ggml_webgpu_set_pipeline_key & other) const { + return type == other.type && inplace == other.inplace; + } +}; + +struct ggml_webgpu_set_pipeline_key_hash { + size_t operator()(const ggml_webgpu_set_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** Get Rows **/ struct ggml_webgpu_get_rows_pipeline_key { @@ -186,6 +211,67 @@ struct ggml_webgpu_pad_pipeline_key_hash { } }; +/** Solve Tri **/ +struct ggml_webgpu_solve_tri_pipeline_key { + int type; + int n; + int k; + + bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const { + return type == other.type && n == other.n && k == other.k; + } +}; + +struct ggml_webgpu_solve_tri_pipeline_key_hash { + size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.n); + ggml_webgpu_hash_combine(seed, key.k); + return seed; + } +}; + +/** SSM Conv **/ +struct ggml_webgpu_ssm_conv_pipeline_key { + int type; + int vectorized; + + bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const { + return type == other.type && vectorized == other.vectorized; + } +}; + +/** Gated Delta Net **/ +struct ggml_webgpu_gated_delta_net_pipeline_key { + int type; + int s_v; + int kda; + + bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const { + return type == other.type && s_v == other.s_v && kda == other.kda; + } +}; + +struct ggml_webgpu_gated_delta_net_pipeline_key_hash { + size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.s_v); + ggml_webgpu_hash_combine(seed, key.kda); + return seed; + } +}; + +struct ggml_webgpu_ssm_conv_pipeline_key_hash { + size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + /** Scale **/ struct ggml_webgpu_scale_pipeline_key { @@ -466,14 +552,22 @@ class ggml_webgpu_shader_lib { unary_pipelines; // type/op/inplace std::unordered_map scale_pipelines; // inplace + std::unordered_map + solve_tri_pipelines; // type + std::unordered_map + ssm_conv_pipelines; // type/vectorized + std::unordered_map + gated_delta_net_pipelines; // type/S_v/kda std::unordered_map - pad_pipelines; // circular/non-circular + pad_pipelines; // circular/non-circular std::unordered_map - binary_pipelines; // type/op/inplace/overlap + binary_pipelines; // type/op/inplace/overlap std::unordered_map - concat_pipelines; // type + concat_pipelines; // type std::unordered_map - repeat_pipelines; // type + repeat_pipelines; // type std::unordered_map flash_attn_pipelines; std::unordered_map set_rows_pipelines; + std::unordered_map set_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -519,11 +614,11 @@ class ggml_webgpu_shader_lib { switch (key.op) { case GGML_OP_RMS_NORM: - defines.push_back("OP_RMS_NORM"); + defines.push_back("RMS_NORM"); variant = "rms_norm"; break; case GGML_OP_L2_NORM: - defines.push_back("OP_L2_NORM"); + defines.push_back("L2_NORM"); variant = "l2_norm"; break; default: @@ -535,8 +630,9 @@ class ggml_webgpu_shader_lib { variant += "_inplace"; } - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - + const uint32_t row_norm_wg_size = 128u; + uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); auto processed = preprocessor.preprocess(wgsl_row_norm, defines); row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); return row_norm_pipelines[key]; @@ -609,6 +705,46 @@ class ggml_webgpu_shader_lib { return set_rows_pipelines[key]; } + webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace }; + + auto it = set_pipelines.find(key); + if (it != set_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "set"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_I32: + defines.push_back("TYPE_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported type for set shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_set, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + set_pipelines[key] = pipeline; + return set_pipelines[key]; + } + webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) { auto it = cumsum_pipelines.find(1); if (it != cumsum_pipelines.end()) { @@ -695,6 +831,7 @@ class ggml_webgpu_shader_lib { switch (key.src_type) { case GGML_TYPE_F32: + defines.push_back("FLOAT_PARALLEL"); if (key.vectorized) { defines.push_back("F32_VEC"); defines.push_back("SRC_TYPE=vec4"); @@ -709,6 +846,7 @@ class ggml_webgpu_shader_lib { variant += "_f32"; break; case GGML_TYPE_F16: + defines.push_back("FLOAT_PARALLEL"); defines.push_back("F16"); defines.push_back("SRC_TYPE=f16"); defines.push_back("DST_TYPE=f32"); @@ -716,6 +854,7 @@ class ggml_webgpu_shader_lib { variant += "_f16"; break; case GGML_TYPE_I32: + defines.push_back("FLOAT_PARALLEL"); defines.push_back("I32"); defines.push_back("SRC_TYPE=i32"); defines.push_back("DST_TYPE=i32"); @@ -794,6 +933,128 @@ class ggml_webgpu_shader_lib { return scale_pipelines[key]; } + webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_solve_tri_pipeline_key key = { + .type = context.dst->type, + .n = (int) context.src0->ne[0], + .k = (int) context.src1->ne[0], + }; + + auto it = solve_tri_pipelines.find(key); + if (it != solve_tri_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "solve_tri"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for solve_tri shader"); + } + + const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size); + const uint32_t k_tile = wg_size; + const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES; + const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row); + + defines.push_back(std::string("N=") + std::to_string(key.n)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("K_TILE=") + std::to_string(k_tile)); + defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n)); + + auto processed = preprocessor.preprocess(wgsl_solve_tri, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + solve_tri_pipelines[key] = pipeline; + return solve_tri_pipelines[key]; + } + + webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_ssm_conv_pipeline_key key = { + .type = context.dst->type, + .vectorized = context.src1->ne[0] == 4, + }; + + auto it = ssm_conv_pipelines.find(key); + if (it != ssm_conv_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "ssm_conv"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for ssm_conv shader"); + } + + if (key.vectorized) { + defines.push_back("VECTORIZED"); + variant += "_vec4"; + } + + constexpr uint32_t block_size = 32u; + constexpr uint32_t tokens_per_wg = 8u; + + defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u"); + defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u"); + + auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines); + auto decisions = std::make_shared(); + decisions->block_size = block_size; + decisions->tokens_per_wg = tokens_per_wg; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + ssm_conv_pipelines[key] = pipeline; + return ssm_conv_pipelines[key]; + } + + webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_gated_delta_net_pipeline_key key = { + .type = context.dst->type, + .s_v = (int) context.src2->ne[0], + .kda = context.src3->ne[0] == context.src2->ne[0], + }; + + auto it = gated_delta_net_pipelines.find(key); + if (it != gated_delta_net_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "gated_delta_net"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for gated_delta_net shader"); + } + + if (key.kda) { + defines.push_back("KDA"); + variant += "_kda"; + } + + defines.push_back("S_V=" + std::to_string(key.s_v) + "u"); + defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u"); + + auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + gated_delta_net_pipelines[key] = pipeline; + return gated_delta_net_pipelines[key]; + } + webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f7973df682a..5e16f84ddd2 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -880,6 +880,68 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g params, entries, wg_x); } +static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + const bool inplace = ggml_webgpu_tensor_equal(src0, dst); + + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = inplace, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst); + const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type); + + std::vector params = { + ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (((const int32_t *) dst->op_params)[3] / dst_type_size), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + 1u, + (uint32_t) (((const int32_t *) dst->op_params)[0] / dst_type_size), + (uint32_t) (((const int32_t *) dst->op_params)[1] / dst_type_size), + (uint32_t) (((const int32_t *) dst->op_params)[2] / dst_type_size), + + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + }; + + std::vector entries; + uint32_t binding_index = 0; + if (!inplace) { + entries.push_back({ .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }); + binding_index++; + } + entries.push_back({ .binding = binding_index, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back({ .binding = binding_index + 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); +} + static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup @@ -935,6 +997,208 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } +static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src1->ne[0], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); + const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + const uint32_t token_tiles = CEIL_DIV((uint32_t) dst->ne[1], decisions->tokens_per_wg); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + + (uint32_t) src1->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + token_tiles, + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); + const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2]; + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .src3 = src3, + .src4 = src4, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx); + + const uint32_t s_v = (uint32_t) src2->ne[0]; + const uint32_t h = (uint32_t) src2->ne[1]; + const uint32_t n_tokens = (uint32_t) src2->ne[2]; + const uint32_t n_seqs = (uint32_t) src2->ne[3]; + const float scale = 1.0f / sqrtf((float) s_v); + uint32_t scale_u32; + memcpy(&scale_u32, &scale, sizeof(scale_u32)); + + std::vector params = { + h, + n_tokens, + n_seqs, + s_v * h * n_tokens * n_seqs, + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[3] / ggml_type_size(src2->type)), + + (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)), + + (uint32_t) src0->ne[1], + (uint32_t) (src2->ne[3] / src0->ne[3]), + scale_u32, + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(src2), + .offset = ggml_webgpu_tensor_align_offset(ctx, src2), + .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, + { .binding = 3, + .buffer = ggml_webgpu_tensor_buf(src3), + .offset = ggml_webgpu_tensor_align_offset(ctx, src3), + .size = ggml_webgpu_tensor_binding_size(ctx, src3) }, + { .binding = 4, + .buffer = ggml_webgpu_tensor_buf(src4), + .offset = ggml_webgpu_tensor_align_offset(ctx, src4), + .size = ggml_webgpu_tensor_binding_size(ctx, src4) }, + { .binding = 5, + .buffer = ggml_webgpu_tensor_buf(src5), + .offset = ggml_webgpu_tensor_align_offset(ctx, src5), + .size = ggml_webgpu_tensor_binding_size(ctx, src5) }, + { .binding = 6, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs); +} + static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, @@ -1016,6 +1280,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { + const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; + ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .src1 = nullptr, @@ -1060,7 +1326,10 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size); + uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1)); + uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]); + uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows; + uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1632,7 +1901,7 @@ static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * s ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, - .max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, .inplace = inplace, }; @@ -2176,6 +2445,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_CPY: case GGML_OP_CONT: return ggml_webgpu_cpy(ctx, src0, node); + case GGML_OP_SET: + return ggml_webgpu_set(ctx, src0, src1, node); case GGML_OP_SET_ROWS: return ggml_webgpu_set_rows(ctx, src0, src1, node); case GGML_OP_GET_ROWS: @@ -2219,6 +2490,12 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_DIAG: case GGML_OP_TRI: return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SOLVE_TRI: + return ggml_webgpu_solve_tri(ctx, src0, src1, node); + case GGML_OP_SSM_CONV: + return ggml_webgpu_ssm_conv(ctx, src0, src1, node); + case GGML_OP_GATED_DELTA_NET: + return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); case GGML_OP_PAD: return ggml_webgpu_pad(ctx, src0, node); case GGML_OP_ARGMAX: @@ -2957,7 +3234,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm /* .is_host = */ NULL, // defaults to false }, /* .device = */ - dev, + dev, /* .context = */ NULL }; @@ -3040,6 +3317,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) || (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32); break; + case GGML_OP_SET: + supports_op = src0->type == src1->type && src0->type == op->type && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32); + break; case GGML_OP_SET_ROWS: supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); @@ -3180,6 +3461,27 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } } break; + case GGML_OP_TRI: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_DIAG: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_SOLVE_TRI: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; + break; + case GGML_OP_SSM_CONV: + supports_op = op->type == GGML_TYPE_F32; + break; + case GGML_OP_GATED_DELTA_NET: + { + const uint32_t s_v = (uint32_t) src2->ne[0]; + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && + src2->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 && + op->src[4]->type == GGML_TYPE_F32 && op->src[5]->type == GGML_TYPE_F32 && + s_v <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + } + break; case GGML_OP_CLAMP: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; @@ -3201,12 +3503,6 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_COS: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; - case GGML_OP_DIAG: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); - break; - case GGML_OP_TRI: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); - break; case GGML_OP_PAD: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl new file mode 100644 index 00000000000..f9d98fda40b --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -0,0 +1,132 @@ +@group(0) @binding(0) +var src_q: array; + +@group(0) @binding(1) +var src_k: array; + +@group(0) @binding(2) +var src_v: array; + +@group(0) @binding(3) +var src_g: array; + +@group(0) @binding(4) +var src_beta: array; + +@group(0) @binding(5) +var src_state: array; + +@group(0) @binding(6) +var dst: array; + +struct Params { + h: u32, + n_tokens: u32, + n_seqs: u32, + s_off: u32, + + sq1: u32, + sq2: u32, + sq3: u32, + + sv1: u32, + sv2: u32, + sv3: u32, + + sb1: u32, + sb2: u32, + sb3: u32, + + neq1: u32, + rq3: u32, + scale: f32, +}; + +@group(0) @binding(7) +var params: Params; + +var sh_k: array; +var sh_q: array; +var sh_g: array; + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(workgroup_id) workgroup_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let head_id = workgroup_id.x; + let seq_id = workgroup_id.y; + let col = local_id.x; + + let iq1 = head_id % params.neq1; + let iq3 = seq_id / params.rq3; + + let state_size = S_V * S_V; + let state_base = (seq_id * params.h + head_id) * state_size; + + var state: array; + for (var i = 0u; i < S_V; i++) { + state[i] = src_state[state_base + col * S_V + i]; + } + + var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V; + + for (var t = 0u; t < params.n_tokens; t++) { + let q_off = iq3 * params.sq3 + t * params.sq2 + iq1 * params.sq1; + let k_off = q_off; + let v_off = seq_id * params.sv3 + t * params.sv2 + head_id * params.sv1; + let gb_off = seq_id * params.sb3 + t * params.sb2 + head_id * params.sb1; + + sh_q[col] = src_q[q_off + col]; + sh_k[col] = src_k[k_off + col]; + +#ifdef KDA + let g_base = gb_off * S_V; + sh_g[col] = exp(src_g[g_base + col]); +#endif + + workgroupBarrier(); + + let v_val = src_v[v_off + col]; + let beta_val = src_beta[gb_off]; + + var kv_col = 0.0; + var delta_col = 0.0; + var attn_col = 0.0; + +#ifdef KDA + for (var i = 0u; i < S_V; i++) { + kv_col += (sh_g[i] * state[i]) * sh_k[i]; + } + + delta_col = (v_val - kv_col) * beta_val; + + for (var i = 0u; i < S_V; i++) { + state[i] = sh_g[i] * state[i] + sh_k[i] * delta_col; + attn_col += state[i] * sh_q[i]; + } +#else + let g_val = exp(src_g[gb_off]); + + for (var i = 0u; i < S_V; i++) { + kv_col += state[i] * sh_k[i]; + } + + delta_col = (v_val - g_val * kv_col) * beta_val; + + for (var i = 0u; i < S_V; i++) { + state[i] = g_val * state[i] + sh_k[i] * delta_col; + attn_col += state[i] * sh_q[i]; + } +#endif + + dst[attn_off + col] = attn_col * params.scale; + attn_off += S_V * params.h; + + workgroupBarrier(); + } + + for (var i = 0u; i < S_V; i++) { + dst[params.s_off + state_base + col * S_V + i] = state[i]; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index b10800e36d2..d9eb6a3567e 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -640,6 +640,35 @@ var params: Params; @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { +#ifdef FLOAT_PARALLEL + let blocks_per_row = params.ne0 / BLOCK_SIZE; + let row_count = params.n_rows * params.ne2 * params.ne3; + + if (gid.x >= blocks_per_row * row_count) { + return; + } + + let block_idx = gid.x % blocks_per_row; + var row_idx = gid.x / blocks_per_row; + let i_dst3 = row_idx / (params.ne2 * params.n_rows); + + row_idx = row_idx % (params.ne2 * params.n_rows); + let i_dst2 = row_idx / params.n_rows; + let i_dst1 = row_idx % params.n_rows; + + let i_idx2 = i_dst3 % params.idx2; + let i_idx1 = i_dst2 % params.idx1; + let i_idx0 = i_dst1; + + let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + + let idx_val = u32(idx[i_idx]); + + let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; + let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; + + copy_elements(i_src_row, i_dst_row, block_idx); +#else if (gid.x >= params.n_rows * params.ne2 * params.ne3) { return; } @@ -664,5 +693,5 @@ fn main(@builtin(global_invocation_id) gid: vec3) { for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) { copy_elements(i_src_row, i_dst_row, i); } +#endif } - diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl index 7777944941c..bd8d32bded7 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl @@ -81,11 +81,12 @@ fn main(@builtin(workgroup_id) wid: vec3, } sum = scratch[0]; -#ifdef OP_RMS_NORM +#ifdef RMS_NORM let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); -#elif OP_L2_NORM +#elif defined(L2_NORM) let scale = 1.0/max(sqrt(sum), params.eps); #endif + col = lid.x; for (var j: u32 = 0; j < elems; j++) { if (col >= params.ne0) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl new file mode 100644 index 00000000000..0a7ae9bdb2c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl @@ -0,0 +1,109 @@ +#ifdef TYPE_I32 +#define TYPE i32 +#else +#define TYPE f32 +#endif + +#ifndef INPLACE +@group(0) @binding(0) +var src0: array; +#define SRC1_BINDING 1 +#else +#define SRC1_BINDING 0 +#endif + +#define DST_BINDING SRC1_BINDING + 1 +#define PARAMS_BINDING SRC1_BINDING + 2 + +@group(0) @binding(SRC1_BINDING) +var src1: array; + +@group(0) @binding(DST_BINDING) +var dst: array; + +struct Params { + ne: u32, + offset_src0: u32, + offset_src1: u32, + offset_view: u32, + + stride_src10: u32, + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst10: u32, + stride_dst11: u32, + stride_dst12: u32, + stride_dst13: u32, + + src1_ne0: u32, + src1_ne1: u32, + src1_ne2: u32, + src1_ne3: u32, +}; + +@group(0) @binding(PARAMS_BINDING) +var params: Params; + +fn decode_src1_coords(idx: u32) -> vec4 { + var i = idx; + let plane = params.src1_ne2 * params.src1_ne1 * params.src1_ne0; + let i3 = i / plane; + i = i % plane; + let row = params.src1_ne1 * params.src1_ne0; + let i2 = i / row; + i = i % row; + let i1 = i / params.src1_ne0; + let i0 = i % params.src1_ne0; + return vec4(i0, i1, i2, i3); +} + +fn decode_view_coords(rel: u32) -> vec4 { + let i3 = rel / params.stride_dst13; + let rem3 = rel % params.stride_dst13; + let i2 = rem3 / params.stride_dst12; + let rem2 = rem3 % params.stride_dst12; + let i1 = rem2 / params.stride_dst11; + let i0 = rem2 % params.stride_dst11; + return vec4(i0, i1, i2, i3); +} + +fn view_rel_from_coords(coords: vec4) -> u32 { + return coords.x * params.stride_dst10 + coords.y * params.stride_dst11 + + coords.z * params.stride_dst12 + coords.w * params.stride_dst13; +} + +fn src1_idx_from_coords(coords: vec4) -> u32 { + return coords.x * params.stride_src10 + coords.y * params.stride_src11 + + coords.z * params.stride_src12 + coords.w * params.stride_src13; +} + +fn in_set_view(rel: u32, coords: vec4) -> bool { + return view_rel_from_coords(coords) == rel; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + +#ifdef INPLACE + let coords = decode_src1_coords(gid.x); + + let src1_idx = params.offset_src1 + src1_idx_from_coords(coords); + let dst_idx = params.offset_view + view_rel_from_coords(coords); + + dst[dst_idx] = src1[src1_idx]; +#else + let rel = select(params.ne, gid.x - params.offset_view, gid.x >= params.offset_view); + let coords = decode_view_coords(rel); + + if (rel < params.stride_dst13 * params.src1_ne3 && in_set_view(rel, coords)) { + dst[gid.x] = src1[params.offset_src1 + src1_idx_from_coords(coords)]; + } else { + dst[gid.x] = src0[params.offset_src0 + gid.x]; + } +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl new file mode 100644 index 00000000000..9d5d902cb1e --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl @@ -0,0 +1,121 @@ +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1: array; + +@group(0) @binding(2) +var dst: array; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src00: u32, + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src10: u32, + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + k: u32, + ne2: u32, + ne3: u32, +}; + +@group(0) @binding(3) +var params: Params; + +var shA: array; +var shB: array; + +fn src0_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_src0 + + col * params.stride_src00 + + row * params.stride_src01 + + i2 * params.stride_src02 + + i3 * params.stride_src03; +} + +fn src1_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_src1 + + col * params.stride_src10 + + row * params.stride_src11 + + i2 * params.stride_src12 + + i3 * params.stride_src13; +} + +fn dst_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_dst + + col * params.stride_dst0 + + row * params.stride_dst1 + + i2 * params.stride_dst2 + + i3 * params.stride_dst3; +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(workgroup_id) workgroup_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let batch = workgroup_id.y; + let col = workgroup_id.x * WG_SIZE + local_id.x; + let i3 = batch / params.ne2; + let i2 = batch % params.ne2; + let active_lane = local_id.x < K_TILE; + let active_col = active_lane && col < params.k; + + var X: array; + + for (var row_base = 0u; row_base < N; row_base += BATCH_N) { + let cur_n = min(BATCH_N, N - row_base); + + for (var i = local_id.x; i < cur_n * N; i += WG_SIZE) { + let tile_row = i / N; + let tile_col = i % N; + shA[i] = src0[src0_idx(row_base + tile_row, tile_col, i2, i3)]; + } + + for (var i = local_id.x; i < cur_n * K_TILE; i += WG_SIZE) { + let tile_row = i / K_TILE; + let tile_col = i % K_TILE; + let global_col = workgroup_id.x * WG_SIZE + tile_col; + let sh_idx = tile_row * K_TILE + tile_col; + + if (global_col < params.k) { + shB[sh_idx] = src1[src1_idx(row_base + tile_row, global_col, i2, i3)]; + } else { + shB[sh_idx] = 0.0; + } + } + + workgroupBarrier(); + + if (active_col) { + for (var row_offset = 0u; row_offset < cur_n; row_offset++) { + let r = row_base + row_offset; + var b = shB[row_offset * K_TILE + local_id.x]; + let a_row = row_offset * N; + + for (var t = 0u; t < r; t++) { + b -= shA[a_row + t] * X[t]; + } + + let x = b / shA[a_row + r]; + X[r] = x; + dst[dst_idx(r, col, i2, i3)] = x; + } + } + + workgroupBarrier(); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl new file mode 100644 index 00000000000..11511305ed8 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl @@ -0,0 +1,65 @@ +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1: array; + +@group(0) @binding(2) +var dst: array; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src01: u32, + stride_src02: u32, + stride_src11: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + + nc: u32, + nr: u32, + n_t: u32, + n_s: u32, + token_tiles: u32, +}; + +@group(0) @binding(3) +var params: Params; + +@compute @workgroup_size(BLOCK_SIZE, TOKENS_PER_WG) +fn main(@builtin(global_invocation_id) gid: vec3) { + let i1 = gid.x; + let tile_y = gid.y / TOKENS_PER_WG; + let local_token = gid.y % TOKENS_PER_WG; + let i3 = tile_y / params.token_tiles; + let token_tile = tile_y % params.token_tiles; + let i2 = token_tile * TOKENS_PER_WG + local_token; + + if (i1 >= params.nr || i2 >= params.n_t || i3 >= params.n_s) { + return; + } + + let src0_base = params.offset_src0 + i3 * params.stride_src02 + i2 + i1 * params.stride_src01; + let src1_base = params.offset_src1 + i1 * params.stride_src11; + + var sum = 0.0; + +#ifdef VECTORIZED + sum = + src0[src0_base + 0u] * src1[src1_base + 0u] + + src0[src0_base + 1u] * src1[src1_base + 1u] + + src0[src0_base + 2u] * src1[src1_base + 2u] + + src0[src0_base + 3u] * src1[src1_base + 3u]; +#else + for (var i0 = 0u; i0 < params.nc; i0++) { + sum += src0[src0_base + i0] * src1[src1_base + i0]; + } +#endif + + let dst_idx = params.offset_dst + i3 * params.stride_dst2 + i2 * params.stride_dst1 + i1 * params.stride_dst0; + dst[dst_idx] = sum; +} From 081dc773a5bbc9c119ccb7ec94a5fca332ccd0d5 Mon Sep 17 00:00:00 2001 From: uvos Date: Thu, 19 Mar 2026 17:05:44 +0100 Subject: [PATCH 024/249] ci : add hip quality check (llama/20430) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CI: add hip quality check * Update scripts/hip/gcn-cdna-vgpr-check.py Co-authored-by: Sigbjørn Skjæret * Update .github/workflows/hip-quality-check.yml Co-authored-by: Sigbjørn Skjæret * Update .github/workflows/hip-quality-check.yml Co-authored-by: Sigbjørn Skjæret * Update .github/workflows/hip-quality-check.yml Co-authored-by: Sigbjørn Skjæret * Update scripts/hip/gcn-cdna-vgpr-check.py Co-authored-by: Sigbjørn Skjæret * Update scripts/hip/gcn-cdna-vgpr-check.py Co-authored-by: Sigbjørn Skjæret * Update scripts/hip/gcn-cdna-vgpr-check.py Co-authored-by: Sigbjørn Skjæret * Update scripts/hip/gcn-cdna-vgpr-check.py Co-authored-by: Sigbjørn Skjæret * Revert "Update .github/workflows/hip-quality-check.yml" This reverts commit efa0bfcdb01dfac0feee674987a0482d50f46145. * scripts: gcn-cdna-vgpr-check.py: enforce int type for total_vgprs * scripts: gcn-cdna-vgpr-check.py: add flash attention instances to ignore list * Bump ccache version * Add mssing seperators to list --------- Co-authored-by: Sigbjørn Skjæret --- ggml/src/ggml-hip/CMakeLists.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index b44ed0f7215..c2357722629 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -53,9 +53,6 @@ endif() message(STATUS "HIP and hipBLAS found") -# Workaround old compilers -set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024") - file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") From 15f6b6ad76ef6aed7814d1212bfc17a8e05e0937 Mon Sep 17 00:00:00 2001 From: Yiwei Shao <44545837+njsyw1997@users.noreply.github.com> Date: Thu, 19 Mar 2026 09:11:06 -0700 Subject: [PATCH 025/249] hexagon: add Matrix Extensions (HMX) for Hexagon NPU backend (llama/20693) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * migrate(vtcm): unify VTCM management for HMX merge - Add HMX fields to htp_context (#ifdef HTP_HAS_HMX): hmx_enabled, hmx_dma, vtcm_scratch_size, exp2_table - Add HTP_VTCM_SESSION_HOLD CMake option (default ON): hold VTCM for entire session instead of per-op acquire/release - Add vtcm_op_acquire/vtcm_op_release inline wrappers: no-op in session-hold mode, delegate in per-op mode - Add VTCM tail reservation for precompute tables (256KB, 64KB aligned) in htp_iface_start under HTP_HAS_HMX - Add HMX init/cleanup hooks in htp_iface_start/stop - Add precompute table recovery in vtcm_acquire after VTCM preemption - Do NOT migrate vtcm_mgr from htp-ops-lib (replaced by tail reservation) * migrate(repack): replace x4x2 with HMX tile-permuted super-block format - Add hmx_block_q4_0/q8_0 struct definitions (scales-first + sequential quants) - Implement forward repack: repack_q4_0_to_hmx_superblock, repack_q8_0_to_hmx_superblock, repack_f16_to_tile_permuted - Implement inverse repack for get_tensor debug verification - Route set_tensor/get_tensor via opt_arch >= 73 to HMX path, else existing HVX x4x2 - MXFP4 on v73+ falls back to HVX x4x2 repack (not memcpy) - Extend supports_op: add IQ4_NL for v73+, F16 tile alignment checks - Tail blocks (K not multiple of 256): repack to x4x2 via pad-repack-truncate - Add CMake GGML_HEXAGON_HMX_TAIL_HVX option (default ON); OFF rejects non-256-aligned K in supports_op * migrate(dma): add dma_queue_push_1d() convenience wrapper for HMX ops Add 1D linear DMA transfer helper to hex-dma.h for upcoming HMX op migration. Reuses existing dma_queue_flush() for sync points instead of adding redundant dma_queue_drain(). * migrate(hmx): reorganize HMX files into htp/hmx/ and simplify HMX locking Move all 14 HMX-related files from htp/ to htp/hmx/ subdirectory for cleaner separation between HVX and HMX code. Simplify HMX hardware locking by replacing the two-level lock design (SHARED HAP lock + custom asm spin-lock) with direct HAP_compute_res_hmx_lock/unlock on the existing vtcm_rctx, which already has HMX capability. Key changes: - Create htp/hmx/ subdirectory with all HMX infrastructure and ops - Replace hmx_mgr_ctx_id + spin-lock with HAP_compute_res_hmx_lock(vtcm_rctx) - Remove hmx_manager_enable/disable_execution() (SHARED lock no longer needed) - Add hmx_set_vtcm_state() call in main.c (was missing, caused null globals) - Update main.c includes to use hmx/ prefix - Clean up duplicate declarations from hmx-worker-pool.h * migrate(hmx-infra): consolidate HMX infrastructure into htp_context - Remove hmx-mgr.c/h: eliminate global HMX state singleton, thread htp_context through all HMX ops - Remove hmx-worker-pool.c/h: replace separate HMX worker pool with main worker_pool API (worker_pool_run_func) - Replace hmx_unit_acquire/release with direct HAP_compute_res_hmx_lock/unlock on ctx->vtcm_rctx - Remove HTP_VTCM_SESSION_HOLD compile option: always use per-op vtcm_acquire/release - Remove hmx_dma from htp_context: HMX ops use ctx->dma[0] instead of separate DMA queue - Simplify main.c init/cleanup: remove hmx_manager_setup/reset and vtcm_op_acquire/release wrappers - Delete upstream llama.cpp AGENTS.md (not applicable to fork) * migrate(flash-attn): remove HTP_EXP2_TABLE_COPIES, use single exp2 table - Remove HTP_EXP2_TABLE_COPIES compile definition and CMake cache variable - Remove table duplication loop in precompute-table.c - Remove worker_index % N sub-table indexing in hmx-flash-attn-ops.c - Fix table_size to 65536 (single 64 KB copy) in main.c The exp2 lookup table is read-only; concurrent VTCM reads do not cause bank conflicts, so duplicating the table wastes 192 KB of VTCM for no benefit. * migrate(dsp-main): add HMX priority dispatch in packet_callback - Add proc_hmx_matmul_req() wrapper for HMX mat_mul (F16 and quantized types) - Add proc_hmx_flash_attn_req() wrapper for HMX simple_flash_attn (FP16 only, falls back to HVX for non-FP16) - Add proc_hmx_rms_norm_req() wrapper using hvx_rms_norm_f32 - Route MUL_MAT, FLASH_ATTN_EXT, RMS_NORM through HMX path when ctx->hmx_enabled - Split RMS_NORM and SCALE into separate case blocks for independent dispatch - All HMX wrappers guarded by #ifdef HTP_HAS_HMX * migrate(cmake-dsp): add HMX source files and -mhmx for v73+ skels Add HTP_VTCM_SESSION_HOLD option (default ON) and v73+ HMX build integration: compile hmx-matmul-ops, hmx-flash-attn-ops, hmx-rms-norm-ops and precompute-table into v73/v75/v79/v81 skels with -mhmx flag and HTP_HAS_HMX=1 definition. v68/v69 skels remain unchanged. * migrate(hmx-ops): fix compile errors in HMX ops for ggml struct compatibility - hmx-matmul-ops.c: include ggml-common.h for block_q4_0/block_q8_0 definitions - hmx-matmul-ops.c: rename quants->qs, scale->d to match upstream ggml field names - hmx-flash-attn-ops.c: suppress -Wunused-function/-Wunused-variable warnings - hmx-flash-attn-ops.c: inline ctx->n_threads, remove unused n_workers variable * hmx: set Q/O element type to fp16 for flash attention The llama.cpp integration passes fp16 Q/O tensors, so qo_fp32_element should be false to match the actual data layout. * hexagon: unify HMX weight format to x4x2, add IQ4_NL and DSP-side fallback Remove the v73+ HMX-specific super-block/tile-permuted weight format and unify all architectures on the HVX x4x2 packed format. The DSP now decides at runtime whether to use the HMX or HVX matmul path based on dimension constraints (M%32, N%32, K%256 alignment), rather than the host rejecting ops in supports_op. This simplifies the host repack logic, eliminates ~400 lines of HMX super-block code, and adds IQ4_NL quantization support across host and DSP. Key changes: - Remove hmx_block_q4_0/q8_0 types, repack functions, and F16 tile permutation (ggml-hexagon.cpp, hmx-quants.h) - Simplify set_tensor/get_tensor to always use x4x2 repack, add IQ4_NL - Force is_host=false so tensor copies go through format conversion - Add HTP_TYPE_IQ4_NL to DSP message protocol (htp-msg.h) - Rewrite DSP dequantizers to work directly on x4x2 layout (hmx-matmul-ops.c) - Fix mxclracc.hf placement: clear per output tile, not once globally - Move HMX eligibility checks to DSP proc_hmx_matmul_req (main.c) - Remove dma_queue_push_1d wrapper, use 2D DMA for weight sub-blocks - Add VTCM allocation overflow asserts - Remove GGML_HEXAGON_HMX_TAIL_HVX build option (CMakeLists.txt) * Enhance HMX debugging capabilities with new tile dumping functions - Introduced hmx_dump_tile_mem and hmx_dump_fp32_tile_region for improved memory layout visualization of tile data. - Updated hmx_dump_tile_rows to provide raw memory output for debugging. - Added debug logging for activation and weight tile pairs during processing to facilitate troubleshooting. - Refined existing macros for dumping HVX vector values to streamline debugging output. These changes aim to enhance the debugging experience for HMX matmul operations, ensuring better visibility into data handling and transformations. * OK for small mat mul * hexagon: fix UDMA roiwidth 16-bit overflow in HMX matmul DMA transfers The UDMA descriptor roiwidth field is 16-bit (max 65535), but large matrix DMA transfers (e.g. 32×2304 = 73728 bytes) exceeded this limit, causing truncated transfers and NaN results. Fix by using 2D DMA (per-row stride × n_rows) instead of 1D (total_size × 1) for all 4 DMA push calls in both x4x2 and fp16 weight paths. Also includes: - Use standard vlut16 instead of _nomatch variant for dequantization - Add per-tile vscatter drain barrier for correctness - Add compile-time HMX_DEBUG_TRACE_VALUES instrumentation (disabled by default) * hexagon: remove HMX RMS norm fallback and re-enable matmul pipeline Remove hmx-rms-norm-ops.c as the HVX RMS norm offers no benefit over the generic unary path. Re-enable DMA pipeline mode for QK matmul. * hexagon: guard all HMX matmul DMA transfers against UDMA 16-bit field overflow All UDMA type1 descriptor fields (roiwidth, roiheight, srcstride, dststride) are 16-bit (max 65535). Commit 40d2a9cc fixed roiwidth overflow in the non-pipeline path by switching from 1D to 2D DMA, but the pipeline path (3 call sites) was left unchanged and still used 1D DMA with chunk_size = n_cols * row_stride as roiwidth, which overflows for any practical matrix size when the pipeline is active. Add a local hmx_dma_push_safe() helper that transparently handles overflow: - Fast path (zero overhead): all params fit in 16 bits -> direct call. - Contiguous block: reshapes into a single 2D descriptor with sub_width that fits in 16 bits, preserving async DMA behavior. - Stride overflow: row-by-row fallback for future large-k models where per-row stride itself exceeds 65535. Convert all 8 external dma_queue_push calls in hmx-matmul-ops.c to use the safe helper, including the 3 pipeline sites (1D -> 2D fix), the FP16 and x4x2 weight paths, qweight_fetch sub-block DMA, and the output-stationary activation fetch. * hexagon: multithread activation/output transfer and add HMX matmul fallback - Replace single-threaded transfer_activation_chunk_fp32_to_fp16 with transfer_activation_chunk_multithread across all HMX matmul paths - Add multi-threaded transfer_output_chunk_multithread for FP16-to-FP32 output store, following the same worker pool pattern - Rename transfer_activation_chunk_no_prefetch back to transfer_activation_chunk_fp32_to_fp16 and clean up stale comments - Add HVX fallback in proc_hmx_matmul_req when HMX matmul returns error * [todo]: dynamic alloc vtcm, cause prefill regression. * hexagon: constrain HMX mxmem tile load region to avoid VTCM bank boundary faults Set activation/weight mxmem Rt to 2047 for single-tile loads and document the 4MB VTCM bank boundary constraint, preventing precise bus errors when dynamic VTCM allocation places tiles near bank edges. * hexagon: split unaligned-M HMX matmul into HMX+HVX phases - keep HMX for the 32-aligned head rows and process tail rows with HVX - force re-quantization for HVX tail after HMX phase to avoid stale VTCM state - preserve fallback behavior when N is unaligned or no aligned M rows exist * hexagon: batch-4 Q4_0 dequantize fast path and remove debug traces Add dequantize_x4x2_q4_0_x4groups_hvx() that processes 4 contiguous K-tiles with a single vmemu + vlut16 per row, reducing per-tile overhead. The dequantize loop now takes the batch-4 path when 4 aligned K-tiles are available within the same column tile, falling back to the original single-tile path otherwise. Also removes HMX_DEBUG_TRACE_VALUES instrumentation blocks that are no longer needed. * hexagon: abort on DSP error and fix HMX-to-HVX fallback quantize flag Promote DSP response error from log to GGML_ABORT for fail-fast behavior. Clear SKIP_QUANTIZE flag when falling back from HMX to HVX matmul so the HVX path correctly re-quantizes activations. * hexagon: support batch matmul. This fix perplexity issue The problem comes from Grouped-Query Attention(GQA). Strides between batches are not well respected TODO: optimize batch matmul to reuse weights between batches. * hexagon: reuse weights in fp16 batch matmul * hexagon: remove unused HMX flash attention operations and precomputation table, remove the log system for test * hexagon: remove unused HVX math helpers, debug infrastructure, and stale build options * hexagon: fix HMX not enabled due to missing force_hvx parameter in IDL * hexagon: remove the unnecessary changes not related to HMX * hexagon: bypass HMX by default * hexagon: add upstream repo link to htp-ops-lib ported file headers * hexagon: restore host buffer support * hexagon: add HMX=1 option for the adb scripts * hex-hmx: improve DMA pipelining * hex-hmx: further improvements to dma pipelining * hex-hmx: minor cleanup * hex-hmx: move hmx lock out of inner loops/calls * hex-hmx: remove unnecessary state and wrappers * hex-hmx: remove hmx dir and unify f32 to f16 conversions * hex-hmx: further unify hvx conversions * hex-hmx: revert f16 converter to the original for now * hex-hmx: minor cleanup for f16 to f32 converter * hex-mm: replace incorrect fp16-to-fp32 hmx converter and reformated related code * hex-dma: move chanied dma push into hex-dma.h header and update hmx-mm * hex-mm: use hex_is_aligned instead of a duplicated hmx_is_aligned * hex-mm: use hvx_vec_splat_f16 in the hmx code * hex-mm: use VLEN and HTP types in hmx-code * hex-mm: remove duplicate QK and defs * hexagon: pre-shuffle quants before vlut16 * hexagon: enable HMX by default * hex-mm: code indent fixes for hmx-matmul * hexagon: update hex-utils to include align/smin/etc helpers and use that in hmx mm * hex-mm: more formatting fixes * hex-mm: minor naming updates in hmx code * hex-mm: remove leftover from rebase conflict * Fix the incorrect indents --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 7 +- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 18 + ggml/src/ggml-hexagon/htp/hex-dma.h | 80 + ggml/src/ggml-hexagon/htp/hex-utils.h | 22 +- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 1528 ++++++++++++++++++++ ggml/src/ggml-hexagon/htp/hmx-ops.h | 72 + ggml/src/ggml-hexagon/htp/hmx-profile.h | 34 + ggml/src/ggml-hexagon/htp/hmx-utils.h | 88 ++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 6 + ggml/src/ggml-hexagon/htp/htp-msg.h | 19 +- ggml/src/ggml-hexagon/htp/htp_iface.idl | 2 +- ggml/src/ggml-hexagon/htp/hvx-base.h | 37 +- ggml/src/ggml-hexagon/htp/main.c | 246 +++- 13 files changed, 2142 insertions(+), 17 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/hmx-ops.h create mode 100644 ggml/src/ggml-hexagon/htp/hmx-profile.h create mode 100644 ggml/src/ggml-hexagon/htp/hmx-utils.h diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 4b8a16c3635..8bcf5291c11 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -45,6 +45,7 @@ static int opt_verbose = 0; static int opt_profile = 0; static int opt_hostbuf = 1; // hostbuf ON by default static int opt_experimental = 0; +static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only // Enable all stages by default static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE; @@ -1693,7 +1694,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { // Start the DSP-side service. We need to pass the queue ID to the // DSP in a FastRPC call; the DSP side will import the queue and start // listening for packets in a callback. - err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx); + err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx); if (err != 0) { GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); @@ -3372,6 +3373,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_etm = getenv("GGML_HEXAGON_ETM"); const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); + const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); const char * str_arch = getenv("GGML_HEXAGON_ARCH"); @@ -3381,8 +3383,9 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; opt_opsync = str_opsync ? atoi(str_opsync) : 0; opt_profile = str_profile ? atoi(str_profile) : 0; - opt_etm = str_etm ? atoi(str_etm) : 0; + opt_etm = str_etm ? atoi(str_etm) : 0; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; + opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index a490a2ce9a1..6ddfe4252f5 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -40,6 +40,24 @@ target_compile_definitions(${HTP_LIB} PRIVATE $,FARF_HIGH=1,> FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) +# HMX acceleration: available on v73+ architectures +set(HTP_HMX_VERSIONS v73 v75 v79 v81) +list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) + +if (_hmx_idx GREATER_EQUAL 0) + target_sources(${HTP_LIB} PRIVATE + hmx-matmul-ops.c + ) + + # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) + set_source_files_properties( + hmx-matmul-ops.c + PROPERTIES COMPILE_OPTIONS "-mhmx" + ) + + target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1) +endif() + build_idl(htp_iface.idl ${HTP_LIB}) set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index 350ab9d966f..9811a07599f 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -175,6 +175,86 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) { return q->capacity; } +// --------------------------------------------------------------------------- +// Overflow-safe DMA push: all UDMA type1 descriptor fields (roiwidth, +// roiheight, srcstride, dststride) are 16-bit, max 65535. This helper +// transparently handles values that exceed the 16-bit limit and submits +// chained DMA transtions. +// +// Case 1 (fast path): all params fit in 16 bits -> direct dma_queue_push. +// Case 2 (contiguous block): width == srcstride == dststride. Reshape the +// flat transfer into a 2D descriptor with sub_width <= 65535. Produces a +// single descriptor, preserving async DMA behavior. +// Case 3 (stride overflow): srcstride or dststride > 65535. Issue rows +// one at a time. The first N-1 rows are pushed+popped synchronously; +// the last row is left async so the caller can pop it. +// --------------------------------------------------------------------------- +#define UDMA_MAX_FIELD_VAL 65535u + +static inline bool dma_queue_push_chained(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t width, size_t nrows) { + // Fast path: everything fits in 16 bits. + if (__builtin_expect( + width <= UDMA_MAX_FIELD_VAL && + nrows <= UDMA_MAX_FIELD_VAL && + src_stride <= UDMA_MAX_FIELD_VAL && + dst_stride <= UDMA_MAX_FIELD_VAL, 1)) { + return dma_queue_push(q, dptr, dst_stride, src_stride, width, nrows); + } + + // Case 2: contiguous block (width == src_stride == dst_stride). + // Reshape total bytes into sub_width * sub_nrows where sub_width <= 65535. + if (width == src_stride && width == dst_stride) { + size_t total = width * nrows; + + // Pick the largest 128-byte-aligned sub_width that divides total evenly. + size_t sub_width = UDMA_MAX_FIELD_VAL & ~(size_t)127; // 65408 + while (sub_width > 0 && total % sub_width != 0) { + sub_width -= 128; + } + if (sub_width == 0) { + // Fallback: use original width (must fit) with adjusted nrows. + // This shouldn't happen for 128-aligned DMA sizes. + sub_width = width; + } + size_t sub_nrows = total / sub_width; + + // Handle sub_nrows > 65535 by issuing chunked descriptors. + const uint8_t *src = (const uint8_t *)dptr.src; + uint8_t *dst = (uint8_t *)dptr.dst; + size_t rows_done = 0; + while (rows_done < sub_nrows) { + size_t chunk = sub_nrows - rows_done; + if (chunk > UDMA_MAX_FIELD_VAL) chunk = UDMA_MAX_FIELD_VAL; + + dma_ptr p = dma_make_ptr(dst + rows_done * sub_width, src + rows_done * sub_width); + if (!dma_queue_push(q, p, sub_width, sub_width, sub_width, chunk)) + return false; + + rows_done += chunk; + // Complete all chunks without waiting except the last one, so the + // caller's single dma_queue_pop drains the final descriptor. + if (rows_done < sub_nrows) + dma_queue_pop_nowait(q); + } + return true; + } + + // Case 3: stride overflow — fall back to row-by-row. + { + const uint8_t *src = (const uint8_t *)dptr.src; + uint8_t *dst = (uint8_t *)dptr.dst; + for (size_t r = 0; r < nrows; ++r) { + dma_ptr p = dma_make_ptr(dst + r * dst_stride, + src + r * src_stride); + if (!dma_queue_push(q, p, 0, 0, width, 1)) + return false; + if (r + 1 < nrows) + dma_queue_pop_nowait(q); + } + return true; + } +} + #ifdef __cplusplus } // extern "C" #endif diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index fb8a25a3f20..8ed1456bc54 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -29,10 +29,22 @@ static inline uint64_t hex_get_pktcnt() { return pktcnt; } -static inline int32_t hex_is_aligned(void * addr, uint32_t align) { +static inline size_t hmx_ceil_div(size_t num, size_t den) { + return (num + den - 1) / den; +} + +static inline int32_t hex_is_aligned(const void * addr, uint32_t align) { return ((size_t) addr & (align - 1)) == 0; } +static inline size_t hex_align_up(size_t v, size_t align) { + return hmx_ceil_div(v, align) * align; +} + +static inline size_t hex_align_down(size_t v, size_t align) { + return (v / align) * align; +} + static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { uint32_t left_off = (size_t) addr & (chunk_size - 1); uint32_t right_off = left_off + n; @@ -43,6 +55,14 @@ static inline uint32_t hex_round_up(uint32_t n, uint32_t m) { return m * ((n + m - 1) / m); } +static inline size_t hex_smin(size_t a, size_t b) { + return a < b ? a : b; +} + +static inline size_t hex_smax(size_t a, size_t b) { + return a > b ? a : b; +} + static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); Q6_l2fetch_AP((void *) p, control); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c new file mode 100644 index 00000000000..c703a049426 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -0,0 +1,1528 @@ +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include +#include + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" + +#include "hex-dma.h" +#include "hvx-utils.h" +#include "hvx-dump.h" +#include "worker-pool.h" +#include "htp-ctx.h" +#include "htp-msg.h" + +#include "hmx-utils.h" +#include "hmx-ops.h" +#include "hmx-profile.h" + +static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, +}; + +static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0, + 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, +}; + +// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile. +// word[i] = i*128 maps K-row-pair i to byte offset i*128 in the tile. +// Column offset (n*4) is added at runtime. Only entries 0..15 are used (masked by predicate). +static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = { + 0*128, 1*128, 2*128, 3*128, 4*128, 5*128, 6*128, 7*128, + 8*128, 9*128, 10*128, 11*128, 12*128, 13*128, 14*128, 15*128, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; + +// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes +#define HMX_X4X2_SCALES_PER_BLK 8 +#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes + +static inline void swap_ptr(void **p1, void **p2) { + void *t = *p1; + *p1 = *p2; + *p2 = t; +} + +typedef struct { + uint8_t *dst; + const uint8_t *src; + dma_queue *dma; + size_t n_rows; + size_t src_stride; // DDR row stride (full row_stride) + size_t dst_stride; // VTCM sub-block row stride + size_t quant_off; // quant byte offset in each DDR row + size_t quant_width; // quant bytes to copy per row + size_t scale_off; // scale byte offset in each DDR row + size_t scale_width; // scale bytes to copy per row +} qweight_fetch_task_state_t; + +// Compute the byte stride of one row in x4x2 format. +// Numerically equals ggml_row_size(type, k) when k is 256-aligned, because +// x4x2 packing has the same density as block_q4_0 / block_q8_0. +// Layout per row: [quants: nb*128 (Q4) or nb*256 (Q8)][scales: nb*16 bytes] +// Total per row = nb * (128+16) = 144*nb (Q4) or nb * (256+16) = 272*nb (Q8). +// Callers must ensure k is a multiple of 256 (enforced by proc_hmx_matmul_req). +static inline size_t get_x4x2_row_stride(int weight_type, int k) { + int nb = (k + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + return (size_t)nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb + case HTP_TYPE_Q8_0: + return (size_t)nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb + default: + return 0; + } +} + +// --- Overflow-safe arithmetic for VTCM budget calculation --- + +static inline bool hmx_mul_overflow(size_t a, size_t b, size_t *out) { + if (a != 0 && b > SIZE_MAX / a) return true; + *out = a * b; + return false; +} + +static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) { + if (a > SIZE_MAX - b) return true; + *out = a + b; + return false; +} + +// Search for optimal (mc, nc) chunk sizes that maximize mc * nc within VTCM budget. +// +// Cost model: total = nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead +// per_n_cost: bytes per nc column (weight + scratch buffers) +// per_m_cost: bytes per mc row (activation) +// per_mn_cost: bytes per mc*nc element (output) +// overhead: fixed bytes (scales 256B, eye_tile 2048B, etc.) +// +// Algorithm: nc sweeps from n_max down by 32, analytically solving for mc_max. +// Returns 0 on success, -1 if VTCM is insufficient. +static int hmx_compute_chunks( + size_t vtcm_total, size_t overhead, + size_t per_n_cost, size_t per_m_cost, size_t per_mn_cost, + int m, int n, + size_t *m_chunk_out, size_t *n_chunk_out, + size_t *total_out) +{ + if (m <= 0 || n <= 0) return -1; + if (vtcm_total <= overhead) return -1; + if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1; + + const size_t usable = vtcm_total - overhead; + size_t best_mn = 0, best_m = 0, best_n = 0; + + const size_t n_max = hex_align_down((size_t)n, HMX_FP16_TILE_N_COLS); + for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) { + // Early exit: if nc * m_max cannot beat best, smaller nc won't either + if (nc * hex_align_down((size_t)m, HMX_FP16_TILE_N_ROWS) <= best_mn) + break; + + size_t n_fixed = 0, ncmn = 0, mc_denom = 0; + if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue; + if (n_fixed >= usable) goto next_nc; + + if (hmx_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc; + if (hmx_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc; + + { + size_t remain = usable - n_fixed; + size_t mc = remain / mc_denom; + mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS); + mc = hex_smin(mc, (size_t)m); + + if (mc > 0 && mc * nc > best_mn) { + best_mn = mc * nc; + best_m = mc; + best_n = nc; + } + } + +next_nc: + if (nc == HMX_FP16_TILE_N_COLS) break; // avoid size_t underflow + } + + if (best_m == 0 || best_n == 0) return -1; + + // Compute exact total (with overflow checks) + size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0; + if (hmx_mul_overflow(best_n, per_n_cost, &t0)) return -1; + if (hmx_mul_overflow(best_m, per_m_cost, &t1)) return -1; + if (hmx_mul_overflow(best_m, best_n, &mn)) return -1; + if (hmx_mul_overflow(mn, per_mn_cost, &t2)) return -1; + if (hmx_add_overflow(t0, t1, &total)) return -1; + if (hmx_add_overflow(total, t2, &total)) return -1; + if (hmx_add_overflow(total, overhead, &total)) return -1; + + *m_chunk_out = best_m; + *n_chunk_out = best_n; + *total_out = total; + return 0; +} + +// forward declaration – defined after transfer_activation_chunk_fp32_to_fp16 +void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride); + +// Scatter row-major FP16 weight (already in VTCM scratch) directly into transposed [K][N] tiles. +// vtcm_src: [n_cols][k] row-major fp16 in VTCM scratch buffer +// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16 +static void interleave_fp16_weight_chunk_to_tiles(__fp16 *restrict vtcm_dst, + const __fp16 *restrict vtcm_src, + int n_cols, int k) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + assert(k % HMX_FP16_TILE_N_COLS == 0); + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + for (int r = 0; r < n_cols; r += 2) { + int ct = r / HMX_FP16_TILE_N_ROWS; // N-dimension tile index + int local_r = r % HMX_FP16_TILE_N_ROWS; // intra-tile row index + const bool next_row_valid = (r + 1) < n_cols; + + // Offset vectors for N-columns local_r and local_r+1, reused across K-tiles. + HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); + HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); + + for (int c = 0; c < k; c += HMX_FP16_TILE_N_COLS) { + int kt = c / HMX_FP16_TILE_N_COLS; + int tile_idx = ct * n_k_tiles + kt; + __fp16 *tile_base = vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS; + + HVX_Vector v0 = hvx_vmemu(vtcm_src + r * k + c); + HVX_Vector v1 = next_row_valid ? hvx_vmemu(vtcm_src + (r + 1) * k + c) : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off0, v0); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off1, v1); + } + } +} + +// --- x4x2 format dequantizers --- + +// Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. +// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles +// of the same 32 packed bytes. +static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx( + const uint8_t *packed_32, bool upper_nibbles, + const __fp16 *scale, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_scales = hvx_vec_splat_f16(*scale); + // q4x4x2 stores two int4 values per byte. Keep only the selected nibble. + HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + // Shuffle before LUT + v_quants = Q6_Vb_vshuff_Vb(v_quants); + // Use standard vlut16 (not _nomatch) to avoid stale-register NaN. + // _nomatch retains the previous destination-register value for colliding + // indices, but the C intrinsic doesn't model the implicit read so the + // compiler may allocate a register containing garbage/NaN. + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); +} + +// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using +// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls. +// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes. +static inline void dequantize_x4x2_q4_0_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_4, const HVX_Vector vlut_cvt, + HVX_Vector out[4]) { + // Load all 128 packed bytes (4 contiguous 32-byte groups) + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + // Shuffle before LUT + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + // Full-width vlut16: 128 byte lookups -> 128 fp16 results in a VectorPair + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); // [group0: 32 fp16 | group1: 32 fp16] + HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] + + // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b + HVX_VectorPred q64 = Q6_Q_vsetq_R(64); + HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[0]), hvx_vec_splat_f16(scales_4[1])); + HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[2]), hvx_vec_splat_f16(scales_4[3])); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); + + // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter + out[0] = v_lo; // group0 already in [0:63] + out[1] = Q6_V_vror_VR(v_lo, 64); // group1 rotated to [0:63] + out[2] = v_hi; // group2 already in [0:63] + out[3] = Q6_V_vror_VR(v_hi, 64); // group3 rotated to [0:63] +} + +// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. +static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx( + const int8_t *quants_32, const __fp16 *scale) { + HVX_Vector vq = hvx_vmemu(quants_32); + HVX_Vector v_scales = hvx_vec_splat_f16(*scale); + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); +} + +// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. +// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes. +// Output: vtcm_dst in tile-major FP16 layout. +static void dequantize_x4x2_weight_to_fp16_tiles_task( + __fp16 *restrict vtcm_dst, + const uint8_t *restrict vtcm_src, + int n_cols, int k_block, + size_t row_stride, int weight_type, + int start_tile, int end_tile) { + + const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; + const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); + const int qrow_size = is_q4 ? (k_block / 2) : k_block; + + const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) + ? hvx_vmem(iq4_nl_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut); + + // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. + // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128 + // maps to K-rows 2i and 2i+1. Column offset (n*4) added per row. + const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes) + + for (int t = start_tile; t < end_tile; ) { + int ct = t / n_k_tiles; // column tile index + int kt = t % n_k_tiles; // K tile index + + // --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row --- + if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + int blk_idx = (kt * 32) / QK_Q4_0x4x2; + int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 + bool upper = (sub_blk_base >= 4); + int packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes + int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales + + __fp16 *tile_bases[4]; + for (int g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + const uint8_t *r0 = vtcm_src + row0 * row_stride; + const uint8_t *r1 = vtcm_src + row1 * row_stride; + + HVX_Vector v0[4], v1[4]; + dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); + if (row1 < n_cols) { + dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt, v1); + } else { + v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero(); + } + + for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); } + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); } + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + + for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } + + t += 4; + continue; + } + + // --- Single-tile fallback --- + __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS; + + if (is_q4) { + int blk_idx = (kt * 32) / QK_Q4_0x4x2; + int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; + bool upper = (sub_blk >= 4); + int byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; + int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; // reset to column 0 + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = vtcm_src + row0 * row_stride; + const uint8_t *r1 = vtcm_src + row1 * row_stride; + + HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx( + r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector v1 = (row1 < n_cols) + ? dequantize_x4x2_q4_0_group_hvx( + r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) + : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } else { + // Q8_0 + int blk_idx = (kt * 32) / QK_Q8_0x4x2; + int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32; + int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32; + int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; // reset to column 0 + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = vtcm_src + row0 * row_stride; + const uint8_t *r1 = vtcm_src + row1 * row_stride; + + HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx( + (const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); + HVX_Vector v1 = (row1 < n_cols) + ? dequantize_x4x2_q8_0_group_hvx( + (const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) + : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; + } + + // Drain HVX scatter write buffer: a vmem load on the same HW thread retires + // all pending scatter entries to VTCM. Without this, the main thread's HMX + // reads may see stale data because atomic_fetch_sub (release) only orders + // regular stores, not the HVX scatter buffer. + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(vtcm_dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +typedef struct { + __fp16 *dst; + const uint8_t *src; + int n_cols; + int k_block; + size_t row_stride; + int weight_type; + int n_tot_tiles; + int n_tiles_per_task; + int n_tasks; +} x4x2_dequantize_state_t; + +static void dequantize_x4x2_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + + dequantize_x4x2_weight_to_fp16_tiles_task( + state->dst, state->src, state->n_cols, state->k_block, + state->row_stride, state->weight_type, start, end); + } +} + +static void dequantize_x4x2_weight_chunk_to_fp16_tiles( + struct htp_context *ctx, __fp16 *vtcm_dst, + const void *vtcm_src, int n_cols, int k_block, + size_t row_stride, int weight_type) { + + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + assert(k_block % HMX_FP16_TILE_N_COLS == 0); + + int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; + int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; + int n_tot_tiles = n_col_tiles * n_k_tiles; + + size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads); + + x4x2_dequantize_state_t state; + state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; + state.n_tot_tiles = n_tot_tiles; + state.n_tiles_per_task = n_tiles_per_task; + state.dst = vtcm_dst; + state.src = (const uint8_t *)vtcm_src; + state.n_cols = n_cols; + state.k_block = k_block; + state.row_stride = row_stride; + state.weight_type = weight_type; + + worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads); +} + +// --- End x4x2 dequantizers --- + +// requires external HMX lock +static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const __fp16 *weight, const __fp16 *scales, + int n_row_tiles, int n_col_tiles, int n_dot_tiles) { + hmx_set_output_scales(scales); + + for (int r = 0; r < n_row_tiles; ++r) { + for (int c = 0; c < n_col_tiles; ++c) { + Q6_mxclracc_hf(); + + const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; + const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; + + for (int k = 0; k < n_dot_tiles; ++k) { + int offset = k * HMX_FP16_TILE_N_ELMS; + hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset); + } + + __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; + hmx_consume_accumulator_fp16(out_tile); + } + } +} + +static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + for (int r = 0; r < n_rows; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; + int r1 = r % HMX_FP16_TILE_N_ROWS; + + #pragma unroll(4) + for (int c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) { + int c0 = c / HMX_FP16_TILE_N_COLS; + + const __fp16 *tile = vtcm_src + (r0 * n_col_tiles + c0) * HMX_FP16_TILE_N_ELMS; + + HVX_Vector v = ((const HVX_Vector *) tile)[r1 / 2]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (dst + (r * n + c + 0)); + volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (dst + (r * n + c + n)); // next row in global memory + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (r + 1 < n_rows) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + } +} + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int n_cols; + int n; // DDR row stride (total output columns) +} output_transfer_task_state_t; + +static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_task_state_t *st = (output_transfer_task_state_t *) data; + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + float *dst = st->dst + chunk_idx * st->n; + const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols; + transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n); + } +} + +static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, + int n_rows, int n_cols, int n) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = 32; // must be multiple of HMX_FP16_TILE_N_ROWS (32) + + output_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.vtcm_src = vtcm_src; + state.n_cols = n_cols; + state.n = n; + + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads); +} + +static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { + return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; +} + +static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { + return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; +} + +static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + const int r2 = hmx_matmul_batch_r2(params); + const int r3 = hmx_matmul_batch_r3(params); + return (const __fp16 *) ((const uint8_t *) params->permuted_weight + + (size_t) (dst_b2 / r2) * params->src0_nb2 + + (size_t) (dst_b3 / r3) * params->src0_nb3); +} + +static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (const float *) ((const uint8_t *) params->activation + + (size_t) dst_b2 * params->src1_nb2 + + (size_t) dst_b3 * params->src1_nb3); +} + +static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (float *) ((uint8_t *) params->dst + + (size_t) dst_b2 * params->dst_nb2 + + (size_t) dst_b3 * params->dst_nb3); +} + +static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, + const hmx_matmul_w16a32_batched_params_t *params) { + int ret = 0; + for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { + for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { + ret = hmx_mat_mul_permuted_w16a32(ctx, + hmx_matmul_dst_batch_ptr(params, b2, b3), + hmx_matmul_activation_batch_ptr(params, b2, b3), + hmx_matmul_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride); + } + } + return ret; +} + +int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { + if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } + if (!params->m || !params->k || !params->n) { return -1; } + if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } + if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } + if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } + if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } + + if (!hex_is_aligned(params->dst, VLEN) || + !hex_is_aligned(params->activation, VLEN) || + !hex_is_aligned(params->permuted_weight, VLEN)) { + return -1; + } + + const int group_size = hmx_matmul_batch_r2(params); + + if (group_size <= 1) { + FARF(MEDIUM, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + } + + // Grouped path: reuse interleaved weight across all q_heads sharing a + // kv_head. Each q_head gets its own activation buffer in VTCM (so + // activation is loaded once per m_chunk and reused across all n_chunks), + // and each q_head is computed individually to avoid tile-major packing + // issues. m_chunk_n_rows is always a multiple of 32 (from + // hmx_compute_chunks), so per-head tile arrays don't overlap. + const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vec_dot_size = params->k * sizeof(__fp16); + + // When the activation has a large stride (e.g. permuted Q tensor with + // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. + // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather + // strided rows into a contiguous block before the F32->F16 conversion. + const bool use_dma_activation = (params->act_stride > params->k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, + /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, + /*per_mn=*/sizeof(__fp16), + params->m, params->n, + &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + } + + const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + + if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { + FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + + FARF(MEDIUM, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, params->m, params->k, params->n, group_size, params->ne13, + m_chunk_n_rows, n_chunk_n_cols, + (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + TIMER_DEFINE(total); + + TIMER_START(total); + + const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + + for (int b3 = 0; b3 < params->ne13; ++b3) { + for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { + const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); + + for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); + + // Pre-load activations for all heads in the group (once per m_chunk). + // When the source is strided (permuted Q), use 2D DMA to gather + // contiguous rows into a VTCM scratch buffer first, then HVX + // converts from the contiguous VTCM buffer. This avoids L2 cache + // thrashing from HVX loads at large strides. + TIMER_START(activation_load); + for (int g = 0; g < group_size; ++g) { + const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; + __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) params->k * sizeof(float); + const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); + dma_queue_push_chained(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + vtcm_f32_act, (int) n_rows, + params->k, params->k); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride); + } + } + TIMER_STOP(activation_load); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + { + const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); + + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < (size_t) params->n) { + const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; + + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k); + swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + // Reuse the interleaved weight for every q_head in this GQA group + for (int g = 0; g < group_size; ++g) { + TIMER_START(hmx_core); + { + const __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + const int n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); + const int n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); + core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, + n_row_tiles, n_col_tiles, params->k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); + } + TIMER_STOP(output_store); + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + } + } + } + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total), + params->m, params->k, params->n, group_size); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); +#endif + + return 0; +} + +int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const __fp16 *restrict permuted_weight, int m, int k, int n, + int act_stride, int weight_stride) { + if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } + if (act_stride < k || weight_stride < k) { return -1; } + if (k % 32 != 0 || n % 32 != 0) { return -1; } + + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; + } + + // --- Dynamic VTCM layout --- + const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vec_dot_size = k * sizeof(__fp16); + + // DMA-based activation gather for strided tensors (see batched path comment). + const bool use_dma_activation = (act_stride > k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + if (hmx_compute_chunks(vtcm_budget, + /*overhead=*/ 256, + /*per_n=*/ 3 * vec_dot_size, // W + S0 + S1 + /*per_m=*/ vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch + /*per_mn=*/ sizeof(__fp16), // O + m, n, + &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + + // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { + FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + + FARF(MEDIUM, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + + TIMER_DEFINE(total); + TIMER_START(total); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + // transfer activation matrix chunk into VTCM + size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + TIMER_START(activation_load); + { + const float *activation_chunk = activation + mr * act_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) k * sizeof(float); + const size_t stride_bytes = (size_t) act_stride * sizeof(float); + dma_queue_push_chained(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_activation, + vtcm_f32_act, n_rows, k, k); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_activation, + activation_chunk, n_rows, k, act_stride); + } + } + TIMER_STOP(activation_load); + + const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + // issue async DMA for the first weight chunk + // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. + // The source rows can be strided (e.g. KV-cache K after ggml_permute). + { + const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); + + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } + + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready + + // issue async DMA for the next weight chunk (double buffering) + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < n) { + const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; + + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + // interleave row-major fp16 from scratch into tile-major in vtcm_weight + interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *)buf_curr, n_cols, k); + + swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + TIMER_START(hmx_core); + { + const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); + } + TIMER_STOP(output_store); + } + + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); + { + size_t weight_size = (size_t)k * n * sizeof(__fp16); + float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); + FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + } +#endif + + return 0; +} + +int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m, + int k, int n, int w_type); + +int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const uint8_t *restrict permuted_weight, int m, int k, int n, + int weight_type) { + if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } + if (k % 32 != 0 || n % 32 != 0) { return -1; } + + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; + } + + // for large m, k (e.g. prefill FFN Down), use out-stationary version + if (m >= 128 && k > n && n > 1024) { + FARF(MEDIUM, "hmx_matmul_qk: OUT-STATIONARY path m=%d k=%d n=%d type=%d (K_BLOCK=512, %d K-iters with fp16 intermediate)", + m, k, n, weight_type, (k + 511) / 512); + return mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); + } + + size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + FARF(MEDIUM, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); + + // --- Dynamic VTCM layout --- + const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vec_dot_size = k * sizeof(__fp16); + const bool use_pipeline = (m >= 128) && (k <= n); + + // Select cost parameters based on execution path + size_t per_n_cost, per_mn_cost; + if (use_pipeline) { + per_n_cost = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) + per_mn_cost = 2 * sizeof(__fp16); // O x 2 (output double buffer) + } else { + per_n_cost = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) + per_mn_cost = sizeof(__fp16); // O x 1 + } + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, + per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost, + m, n, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d pipe=%d budget=%zu)", + __func__, m, k, n, use_pipeline, vtcm_budget); + return -1; + } + + // Compute precise buffer sizes per execution path + const size_t weight_area_size = hex_align_up( + n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up( + m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + + size_t scratch0_size, scratch1_size, scratch2_size; + if (use_pipeline) { + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 + scratch1_size = scratch0_size; // dequant buf 1 + scratch2_size = output_area_size; // output buf 1 + } else { + scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0 + scratch1_size = scratch0_size; // x4x2 DMA buf 1 + scratch2_size = 0; // unused + } + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); + void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { + FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + + FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, m, k, n, weight_type, use_pipeline, + m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + + TIMER_DEFINE(total); + TIMER_START(total); + + FARF(MEDIUM, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", + use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + if (!use_pipeline) { + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + // transfer activation matrix chunk into VTCM + size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + TIMER_START(activation_load); + { + const float *activation_chunk = activation + mr * k; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); + } + TIMER_STOP(activation_load); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + // issue async DDR data transfer for the first weight chunk + // NOTE: use 2D DMA (n_cols rows x row_stride bytes) instead of 1D + // because UDMA roiwidth is 16-bit and total size can exceed 65535. + { + const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); + } + + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); // wait until current weight chunk become ready + + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < n) { + const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); + + const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride; + + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next); + } + + // Dequant + vscatter writes directly to [K, N] transposed tiles. + // HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight. + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type); + + swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + TIMER_START(hmx_core); + { + const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); + } + TIMER_STOP(output_store); + } + } + } else { + // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) + // stage B and D (dequantize and store) are expected to be on the critical path + + // A --> B: vtcm_qweight, 1 buffer + // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers + // C --> D: vtcm_output0/vtcm_output1, 2 buffers + + // + // LD ||A3| | B3 || + // MM || C2 || + // ST || D1 | || + + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + // Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow. + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); + } + + { + const float *activation_chunk = activation + mr * k; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); + } + + // prologue: B0, A1, C0, B1 + { + // B0 + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); + + // A1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); + } + + // C0 + core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + + // B1 + if (1 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); + } + } + + // main loop + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + + // issue A_{i+2} + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); + } + + // wait for HMX (C_{i}) -- C_{i} is done + + // result of B_{i+1} (input of C_{i+1}) should be ready now + + // issue C_{i+1} + if (i + 1 < n_chunk_cnt) { + core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[(i + 1) % 2], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + } + + // compute D_{i} + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); + + // wait for DMA (A_{i+2}), compute B_{i+2} + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); + } + } + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d pipeline=%d", __func__, TIMER_US(total), m, k, n, use_pipeline); + if (!use_pipeline) { + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); + size_t weight_size = (size_t)n * row_stride; + float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); + FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + } +#endif + + return 0; +} + +// C += AB +void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp16 *col_scales, const __fp16 *eye_tile, + int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { + + hmx_set_output_scales(col_scales); + + for (int i = 0; i < n_row_tiles; ++i) { + for (int j = 0; j < n_col_tiles; ++j) { + Q6_mxclracc_hf(); + + const __fp16 *row_tiles = a + i * n_dot_tiles * HMX_FP16_TILE_N_ELMS; + const __fp16 *col_tiles = b + j * n_dot_tiles * HMX_FP16_TILE_N_ELMS; + + __fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS; + if (!zero_init) { + hmx_load_tile_pair_fp16(accum_tile, eye_tile); + } + + for (int k = 0; k < n_dot_tiles; ++k) { + int offset = k * HMX_FP16_TILE_N_ELMS; + hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset); + } + + hmx_consume_accumulator_fp16(accum_tile); + } + } +} + +static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, + int k_block, int k_stride) { + for (int r = 0; r < n_rows; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + const bool next_row_valid = (r + 1) < n_rows; + + const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); + const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + // compute output position + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + int k_stride; +} activation_transfer_task_state_t; + +static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + // one chunk: one row + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + __fp16 *dst = st->dst + chunk_idx * st->k_block; + const float *src = st->src + chunk_idx * st->k_stride; + transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); + } +} + +void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { + assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); + assert(VLEN == 32 * sizeof(float)); + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address + + activation_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.src = src; + state.k_block = k_block; + state.k_stride = k_stride; + + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); +} + +int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m, + int k, int n, int weight_type) { + // Runtime check -- k >= 16384 exceeds 2D DMA limit + if (k >= 16384) { + FARF(HIGH, "%s: k=%d exceeds 2D DMA limit", __func__, k); + return -1; + } + // assume k % 32 == 0 && n % 32 == 0 + const size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + const size_t vtcm_budget = ctx->vtcm_scratch_size; + + const size_t M_BLOCK_SIZE = 512; + const size_t N_BLOCK_SIZE = 512; + const size_t K_BLOCK_SIZE = 512; + + // Compute precise buffer sizes + const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); + const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); + const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); + + const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; + if (total_vtcm > vtcm_budget) { + FARF(HIGH, "%s: VTCM too small: need %zu have %zu (m=%d k=%d n=%d)", __func__, total_vtcm, vtcm_budget, m, k, n); + return -1; + } + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); + uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); + uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); + __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); + + FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu", + __func__, m, k, n, weight_type, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + // initialize eye tile (32x32 identity matrix) + { + HVX_Vector v; + v = Q6_V_vzero(); + v = Q6_Vw_vinsert_VwR(v, 0x3c000000); + v = Q6_V_vror_VR(v, VLEN - 4); + v = Q6_Vw_vinsert_VwR(v, 0x00003c00); + for (int i = 0; i < 16; ++i) { + ((HVX_Vector *) vtcm_eye_tile)[i] = v; + v = Q6_V_vror_VR(v, VLEN - 8); + } + } + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + + TIMER_DEFINE(fetch); + TIMER_DEFINE(act_load); + TIMER_DEFINE(wt_dequant); + TIMER_DEFINE(core); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { + size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); + for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { + size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); + + const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); + const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); + + for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { + size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); + + TIMER_START(fetch); + // fetch activation block into VTCM + { + const float *activation_block = x + mr * k + kk; + + dma_queue_push_chained(ctx->dma[0], + dma_make_ptr(vtcm_scratch1, activation_block), + k_blk_sz * sizeof(float), + k * sizeof(float), + k_blk_sz * sizeof(float), + m_blk_sz); + } + + // fetch weight block into VTCM (x4x2 sub-block: quants + scales) + { + qweight_fetch_task_state_t s; + + const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); + const int blk_start = kk / QK_Q4_0x4x2; + const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; + const int full_qrow = is_q4 ? (k / 2) : k; + const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); + + s.dst = vtcm_scratch0; + s.src = w + nc * row_stride; + s.n_rows = n_blk_sz; + s.src_stride = row_stride; + s.dst_stride = sub_row_stride; + s.quant_off = is_q4 ? (blk_start * (QK_Q4_0x4x2 / 2)) : (blk_start * QK_Q8_0x4x2); + s.quant_width = is_q4 ? (nb_sub * (QK_Q4_0x4x2 / 2)) : (nb_sub * QK_Q8_0x4x2); + s.scale_off = full_qrow + blk_start * HMX_X4X2_DBLK_SIZE; + s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE; + + // 2D DMA: quants sub-range + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off), + s.dst_stride, s.src_stride, s.quant_width, s.n_rows); + // 2D DMA: scales sub-range + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off), + s.dst_stride, s.src_stride, s.scale_width, s.n_rows); + } + TIMER_STOP(fetch); + + TIMER_START(act_load); + // load activation block + { + dma_queue_pop(ctx->dma[0]); // wait for act DNA + transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); + } + TIMER_STOP(act_load); + + TIMER_START(wt_dequant); + // dequantize weight block + { + dma_queue_pop(ctx->dma[0]); + dma_queue_pop(ctx->dma[0]); + // vtcm_scratch0 is used to store the qweight chunk + // worker_pool_run_func already returned, so fetch is done + const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, + n_blk_sz, k_blk_sz, sub_row_stride, weight_type); + } + TIMER_STOP(wt_dequant); + + // core mma + TIMER_START(core); + { + core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, + n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); + } + TIMER_STOP(core); + } + + // store output block + { + float *output_block = out + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); + } + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", + TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); +#endif + return 0; +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h new file mode 100644 index 00000000000..b36c8d129ba --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -0,0 +1,72 @@ +// HMX operation entry-point declarations. +// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib) + +#ifndef HMX_OPS_H +#define HMX_OPS_H + +#include +#include + +#ifndef restrict +# define restrict __restrict +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +struct htp_context; // forward declaration + +typedef struct { + float *dst; + const float *activation; + const __fp16 *permuted_weight; + int m; + int k; + int n; + int act_stride; + int weight_stride; + int dst_stride; + int ne02; + int ne03; + int ne12; + int ne13; + size_t src0_nb2; + size_t src0_nb3; + size_t src1_nb2; + size_t src1_nb3; + size_t dst_nb2; + size_t dst_nb3; +} hmx_matmul_w16a32_batched_params_t; + +// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output +// act_stride: activation row stride in elements (= k for contiguous, or +// nb[1]/sizeof(float) for permuted tensors like attention Q). +// weight_stride: weight row stride in elements (= k for compact weights, or +// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK). +int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const __fp16 *permuted_weight, + int m, int k, int n, + int act_stride, + int weight_stride); + +// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32. +// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. +int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, + const hmx_matmul_w16a32_batched_params_t *params); + +// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL) +int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int weight_type); + +#ifdef __cplusplus +} +#endif + +#endif // HMX_OPS_H diff --git a/ggml/src/ggml-hexagon/htp/hmx-profile.h b/ggml/src/ggml-hexagon/htp/hmx-profile.h new file mode 100644 index 00000000000..01eece720c5 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-profile.h @@ -0,0 +1,34 @@ +// Conditional fine-grained profiling macros for HMX operations. +// +// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this +// header) to instrument sub-operation latencies with HAP qtimer. When the +// macro is not defined the TIMER_* helpers expand to nothing so there is zero +// overhead. +// +// Usage: +// TIMER_DEFINE(my_phase); // declare accumulator variable +// TIMER_START(my_phase); // snapshot start time +// ... work ... +// TIMER_STOP(my_phase); // accumulate elapsed ticks +// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase)); + +#ifndef HMX_PROFILE_H +#define HMX_PROFILE_H + +#include + +// #define ENABLE_PROFILE_TIMERS + +#if defined(ENABLE_PROFILE_TIMERS) +# define TIMER_DEFINE(name) int64_t name##_ticks = 0 +# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count() +# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0 +# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks) +#else +# define TIMER_DEFINE(name) +# define TIMER_START(name) +# define TIMER_STOP(name) +# define TIMER_US(name) 0LL +#endif + +#endif // HMX_PROFILE_H diff --git a/ggml/src/ggml-hexagon/htp/hmx-utils.h b/ggml/src/ggml-hexagon/htp/hmx-utils.h new file mode 100644 index 00000000000..aacfbcda287 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-utils.h @@ -0,0 +1,88 @@ +// HMX tile-level inline helpers (FP16 32x32 tile operations). +// Ported from htp-ops-lib/include/dsp/hmx_utils.h. (https://github.com/haozixu/htp-ops-lib) + +#ifndef HMX_UTILS_H +#define HMX_UTILS_H + +#include +#include + +#define HMX_FP16_TILE_N_ROWS 32 +#define HMX_FP16_TILE_N_COLS 32 +#define HMX_FP16_TILE_N_ELMS 1024 +#define HMX_FP16_TILE_SIZE 2048 + +#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline)) + +static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) { + asm volatile("bias = mxmem2(%0)" :: "r"(scales)); +} + +// Initialise aligned 256-byte area with scale vector + zero padding. +static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) { + HVX_Vector *pv = (HVX_Vector *)out_scales; + *pv++ = v_scale; + *pv = Q6_V_vzero(); +} + +// Load multiple contiguous tiles with :deep streaming. +// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt]. +// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank +// boundary, otherwise the mxmem instruction will raise a precise bus error. +// Callers must ensure their VTCM layout satisfies this constraint. +static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles, + const __fp16 *col_tiles, + size_t n_tiles) { + size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1; + asm volatile( + "{ activation.hf = mxmem(%0, %1):deep\n" + "weight.hf = mxmem(%2, %3) }\n" + :: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit) + : "memory"); +} + +// Load a single activation+weight tile pair (no :deep streaming). +// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula +// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047. +// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation +// places a tile near a 4 MB bank boundary, the oversized region crosses it and +// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly +// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047). +static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile, + const __fp16 *wt_tile) { + asm volatile( + "{ activation.hf = mxmem(%0, %1)\n" + "weight.hf = mxmem(%2, %3) }\n" + :: "r"(act_tile), "r"(2047), + "r"(wt_tile), "r"(2047) + : "memory"); +} + +static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) { + // Use the combined convert-and-store instruction (matches the reference + // Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence + // "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter. + asm volatile( + "mxmem(%0, %1):after.hf = acc\n" + :: "r"(out), "r"(0) + : "memory"); +} + +// Compute inner product of two vectors of tiles and store result. +static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out, + const __fp16 *row_tiles, + const __fp16 *col_tiles, + size_t n_tiles) { + hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles); + hmx_consume_accumulator_fp16(out); +} + +// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) --- + +static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { + uint8_t *p = *vtcm_ptr; + *vtcm_ptr += size; + return p; +} + +#endif // HMX_UTILS_H diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index a707d98239c..a92acfa0a85 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -30,6 +30,12 @@ struct htp_context { atomic_bool vtcm_needs_release; uint32_t opmask; + + // HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX) +#ifdef HTP_HAS_HMX + int hmx_enabled; // Runtime flag: HMX initialisation succeeded + size_t vtcm_scratch_size; // Usable dynamic scratch (vtcm_size minus tail reservation) +#endif }; #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 56bc5b622c5..391148be0e9 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -32,13 +32,14 @@ enum htp_status { // Duplicated here because we can't include full ggml.h in the htp build. // We have some static_asserts in the cpp code to ensure things are in sync. enum htp_data_type { - HTP_TYPE_F32 = 0, - HTP_TYPE_F16 = 1, - HTP_TYPE_Q4_0 = 2, - HTP_TYPE_Q8_0 = 8, - HTP_TYPE_I32 = 26, - HTP_TYPE_I64 = 27, - HTP_TYPE_MXFP4 = 39, + HTP_TYPE_F32 = 0, + HTP_TYPE_F16 = 1, + HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q8_0 = 8, + HTP_TYPE_IQ4_NL = 20, + HTP_TYPE_I32 = 26, + HTP_TYPE_I64 = 27, + HTP_TYPE_MXFP4 = 39, HTP_TYPE_COUNT }; @@ -87,6 +88,8 @@ static inline size_t htp_t_block_size(uint32_t t) { return QK4_0; case HTP_TYPE_Q8_0: return QK8_0; + case HTP_TYPE_IQ4_NL: + return QK4_NL; case HTP_TYPE_MXFP4: return QK_MXFP4; default: @@ -105,6 +108,8 @@ static inline size_t htp_type_nbytes(uint32_t t) { return sizeof(block_q4_0); case HTP_TYPE_Q8_0: return sizeof(block_q8_0); + case HTP_TYPE_IQ4_NL: + return sizeof(block_iq4_nl); case HTP_TYPE_MXFP4: return sizeof(block_mxfp4); default: diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index 9ebd937e46d..2dc716cb441 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -7,7 +7,7 @@ #include "remote.idl" interface htp_iface : remote_handle64 { - AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx); + AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx); AEEResult stop(); AEEResult enable_etm(); AEEResult disable_etm(); diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index 3e6a8579b1f..db05ab40d28 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -9,6 +9,9 @@ #include "hex-utils.h" #include "hvx-types.h" +#define hvx_vmem(A) *((HVX_Vector *)(A)) +#define hvx_vmemu(A) *((HVX_UVector *)(A)) + static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) { // Rotate as needed. v = Q6_V_vlalign_VVR(v, v, (size_t) dst); @@ -112,11 +115,15 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) { return Q6_Q_and_QQ(p_exp, p_frac); } -static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { - const HVX_Vector zero = Q6_V_vsplat_R(0); +static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) { + const HVX_Vector zero = Q6_V_vzero(); HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero); HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero); - HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0))); + return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)); +} + +static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { + HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1)); #if __HVX_ARCH__ < 79 // replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0) @@ -128,6 +135,30 @@ static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { return v; } +#if __HVX_ARCH__ >= 79 +static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v, one); + return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p)); +} +static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one); + return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p)); +} +#else +static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v, one); + return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p))); +} +static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one); + return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p))); +} +#endif + /* Q6_Vsf_equals_Vw is only available on v73+.*/ #if __HVX_ARCH__ < 73 static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in) diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 2a3f9e562b7..ef9cba8ecc1 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -25,6 +25,10 @@ #include "htp-ops.h" #include "worker-pool.h" +#ifdef HTP_HAS_HMX +#include "hmx-ops.h" +#endif // HTP_HAS_HMX + AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { struct htp_context * ctx; int err = 0; @@ -163,6 +167,9 @@ static int vtcm_acquire(struct htp_context * ctx) { } ctx->vtcm_inuse = true; + + + return 0; } @@ -246,7 +253,7 @@ static void vtcm_free(struct htp_context * ctx) { static void htp_packet_callback(dspqueue_t queue, int error, void * context); static void htp_error_callback(dspqueue_t queue, int error, void * context); -AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) { +AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { @@ -280,6 +287,21 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que return AEE_ENOMEMORY; } +#ifdef HTP_HAS_HMX + if (use_hmx) { + ctx->vtcm_scratch_size = ctx->vtcm_size; + ctx->hmx_enabled = 1; + + FARF(HIGH, "HMX enabled: vtcm-scratch %zu", ctx->vtcm_scratch_size); + } else { + // HMX disabled: skip HMX initialisation so the + // dispatch loop falls through to the HVX compute paths. + ctx->hmx_enabled = 0; + ctx->vtcm_scratch_size = ctx->vtcm_size; + FARF(HIGH, "HMX disabled (use_hmx=0): vtcm-scratch %zu", ctx->vtcm_scratch_size); + } +#endif + qurt_sysenv_max_hthreads_t hw_threads; qurt_sysenv_get_max_hw_threads(&hw_threads); uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF; @@ -340,6 +362,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) { for (int i = 0; i < ctx->n_threads; i++) { dma_queue_delete(ctx->dma[i]); } +#ifdef HTP_HAS_HMX + if (ctx->hmx_enabled) { + ctx->hmx_enabled = 0; + } +#endif + vtcm_free(ctx); @@ -375,8 +403,9 @@ static int send_htp_rsp(struct htp_context * c, struct dspqueue_buffer * bufs, size_t n_bufs, struct profile_data * prof) { - // Prep response struct + // Prep response struct (zero-init to clear cmp/unused union) struct htp_general_rsp rsp; + memset(&rsp, 0, sizeof(rsp)); rsp.op = op; rsp.status = status; rsp.prof_usecs = prof->usecs; @@ -1037,6 +1066,210 @@ static void proc_flash_attn_ext_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof); } +#ifdef HTP_HAS_HMX +// --------------------------------------------------------------------------- +// HMX operation wrappers — self-contained, bypass htp_ops_context / htp_spad. +// VTCM, DMA and thread dispatch are managed inside the HMX kernels. +// --------------------------------------------------------------------------- + +static void proc_hmx_matmul_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + size_t n_bufs) { + // HMX weight tile requires N to be 32-aligned. + if (req->src0.ne[1] % 32 != 0) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + + const bool is_batched = (req->src0.ne[2] * req->src0.ne[3] > 1 || + req->src1.ne[2] * req->src1.ne[3] > 1); + + // Quantised HMX kernels only handle flat 2D matmul (host already rejects + // batched quantised, but guard here too). F16 batched matmul is handled + // by the dedicated wrapper in hmx-matmul-ops.c. + if (is_batched && + req->src0.type != HTP_TYPE_F16) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + + // HMX assumes contiguous row-major layout. Fall back for permuted + // tensors where strides are non-monotonic (e.g. transposed KV cache). + if (req->src0.nb[0] > req->src0.nb[1] || + req->src1.nb[0] > req->src1.nb[1]) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + + // M alignment: when M > 32 but not 32-aligned, we split into + // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows). + // When M <= 32 and not 32-aligned, fall back entirely to HVX. + const int m_total = (int) req->src1.ne[1]; + const int m_tail = m_total % 32; + const int m_hmx = m_total - m_tail; + + if (m_hmx == 0) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + + // HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights. + // Other types (e.g. MXFP4) fall back to HVX. + { + uint32_t wtype = req->src0.type; + if (wtype != HTP_TYPE_F16 && + wtype != HTP_TYPE_Q4_0 && + wtype != HTP_TYPE_Q8_0 && + wtype != HTP_TYPE_IQ4_NL) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + // Quantised HMX path requires K aligned to 256 (x4x2 super-block). + // F16 HMX path requires K aligned to 32 (tile width). + if (wtype != HTP_TYPE_F16 && req->src0.ne[0] % 256 != 0) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + if (wtype == HTP_TYPE_F16 && req->src0.ne[0] % 32 != 0) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + } + + (void) n_bufs; + + struct dspqueue_buffer rsp_bufs[1]; + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); + + // src0 = weights, src1 = activation, dst = output + void * wgt = (void *) bufs[0].ptr; + float * act = (float *) bufs[1].ptr; + float * dst = (float *) bufs[2].ptr; + + int k = (int) req->src0.ne[0]; // inner dimension + int n = (int) req->src0.ne[1]; // weight columns + + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + + // --- Phase 1: HMX on the first m_hmx (32-aligned) rows --- + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + int ret = -1; + + const int ne02 = (int) req->src0.ne[2]; + const int ne03 = (int) req->src0.ne[3]; + const int ne12 = (int) req->src1.ne[2]; + const int ne13 = (int) req->src1.ne[3]; + // Row strides in elements. For compact tensors these equal k; for + // permuted attention views they can be larger, so pass the real stride. + const int act_stride = (int)(req->src1.nb[1] / sizeof(float)); + const int weight_stride = (int)(req->src0.nb[1] / sizeof(__fp16)); + + switch (req->src0.type) { + case HTP_TYPE_F16: + if (is_batched) { + hmx_matmul_w16a32_batched_params_t batch_params = { + .dst = dst, + .activation = act, + .permuted_weight = (const __fp16 *) wgt, + .m = m_hmx, + .k = k, + .n = n, + .act_stride = act_stride, + .weight_stride = weight_stride, + .dst_stride = (int)(req->dst.nb[1] / sizeof(float)), + .ne02 = ne02, + .ne03 = ne03, + .ne12 = ne12, + .ne13 = ne13, + .src0_nb2 = req->src0.nb[2], + .src0_nb3 = req->src0.nb[3], + .src1_nb2 = req->src1.nb[2], + .src1_nb3 = req->src1.nb[3], + .dst_nb2 = req->dst.nb[2], + .dst_nb3 = req->dst.nb[3], + }; + ret = hmx_mat_mul_permuted_w16a32_batched(ctx, &batch_params); + } else { + ret = hmx_mat_mul_permuted_w16a32(ctx, dst, act, + (const __fp16 *) wgt, + m_hmx, k, n, + act_stride, + weight_stride); + } + break; + default: + ret = hmx_mat_mul_permuted_qk_0_d16a32(ctx, dst, act, + (const uint8_t *) wgt, + m_hmx, k, n, (int) req->src0.type); + break; + } + + if (ret == 0) { + rsp_status = HTP_STATUS_OK; + } else { + FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret); + vtcm_release(ctx); + req->flags &= ~HTP_OPFLAGS_SKIP_QUANTIZE; + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + vtcm_release(ctx); + } + + // --- Phase 2: HVX on the remaining m_tail rows --- + if (m_tail > 0 && rsp_status == HTP_STATUS_OK) { + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; // weights: unchanged + octx.src1 = req->src1; + octx.src1.ne[1] = m_tail; // only tail rows + octx.dst = req->dst; + octx.dst.ne[1] = m_tail; // only tail rows + // Always re-quantize tail src1: HMX Phase 1 overwrites VTCM, + // so any previously cached quantized data (SKIP_QUANTIZE pipeline) + // is invalid. + octx.flags = req->flags & ~HTP_OPFLAGS_SKIP_QUANTIZE; + octx.op = req->op; + octx.n_threads = ctx->n_threads; + + // Offset activation and dst pointers past the HMX-processed rows. + // Use nb[1] (row stride in bytes) to compute the byte offset. + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t)((uint8_t *) bufs[1].ptr + (size_t) m_hmx * req->src1.nb[1]); + octx.dst.data = (uint32_t)((uint8_t *) bufs[2].ptr + (size_t) m_hmx * req->dst.nb[1]); + + FARF(HIGH, "proc_hmx_matmul: HVX tail m_tail=%d act=%p dst=%p", + m_tail, (void *)(uintptr_t) octx.src1.data, (void *)(uintptr_t) octx.dst.data); + + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + uint32_t hvx_ret = op_matmul(&octx); + vtcm_release(ctx); + if (hvx_ret != HTP_STATUS_OK) { + FARF(ERROR, "HVX tail matmul failed (ret=%u)", hvx_ret); + rsp_status = HTP_STATUS_INTERNAL_ERR; + } + } else { + rsp_status = HTP_STATUS_INTERNAL_ERR; + } + } + + profile_stop(&prof); + + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +#endif // HTP_HAS_HMX + static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_context * ctx = (struct htp_context *) context; @@ -1089,7 +1322,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { FARF(ERROR, "Bad matmul-req buffer list"); continue; } - proc_matmul_req(ctx, &req, bufs, n_bufs); +#ifdef HTP_HAS_HMX + if (ctx->hmx_enabled) { + proc_hmx_matmul_req(ctx, &req, bufs, n_bufs); + } else +#endif + { + proc_matmul_req(ctx, &req, bufs, n_bufs); + } break; case HTP_OP_MUL_MAT_ID: From e1cdce46c5e795932ad9bc1470c38a31cf1bd05c Mon Sep 17 00:00:00 2001 From: Rail Chabdarov Date: Thu, 19 Mar 2026 19:14:08 +0100 Subject: [PATCH 026/249] hip: Avoid compiler bug in RDNA code generation during debug builds on Windows (llama/20655) --- ggml/src/ggml-hip/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index c2357722629..f96c6e09a9b 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -129,6 +129,11 @@ endif() if (CXX_IS_HIPCC) set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) + if (WIN32 AND CMAKE_BUILD_TYPE STREQUAL "Debug") + # CMake on Windows doesn't support the HIP language yet. + # Therefore we workaround debug build's failure on HIP backend this way. + set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES COMPILE_FLAGS "-O2 -g") + endif() target_link_libraries(ggml-hip PRIVATE hip::device) else() set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP) From 65d820a44a6c95b88ae121918202fea9b4ba0d10 Mon Sep 17 00:00:00 2001 From: Sundaram krishnan <104441812+sundaram123krishnan@users.noreply.github.com> Date: Fri, 20 Mar 2026 01:06:23 +0530 Subject: [PATCH 027/249] ggml: guard KleidiAI DOWNLOAD_EXTRACT_TIMESTAMP for cmake < 3.24 (llama/20767) --- ggml/src/ggml-cpu/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 7c062a62995..1a1bbc9f2be 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -572,9 +572,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name) set(KLEIDIAI_FETCH_ARGS URL ${KLEIDIAI_DOWNLOAD_URL} - DOWNLOAD_EXTRACT_TIMESTAMP NEW URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5} ) + if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") + list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW) + endif() if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.28") FetchContent_Declare(KleidiAI_Download From 46dcb35aa38f10eb5e1eb6f7c2de071e928a60bc Mon Sep 17 00:00:00 2001 From: hipudding Date: Fri, 20 Mar 2026 17:08:39 +0800 Subject: [PATCH 028/249] CANN: add BF16 support for core operators (llama/20152) * CANN: add BF16 support for core operators Add BF16 (bfloat16) type support to the CANN backend for the following operators: MUL_MAT, MUL_MAT_ID, GET_ROWS, SET_ROWS, CPY, CONT, and OUT_PROD. This enables BF16 models to run on Ascend NPUs. * CANN: skip NZ weight format for BF16 and add 310P compile guards NZ weight format conversion does not support BF16 tensors, skip it in set_tensor, get_alloc_size and mul_mat. Remove BF16 from MUL_MAT_ID and OUT_PROD as there are no BF16 use cases. Add #ifndef ASCEND_310P guards for all BF16 operator support since 310P does not support BF16. --- ggml/src/ggml-cann/aclnn_ops.cpp | 12 +++++++++--- ggml/src/ggml-cann/ggml-cann.cpp | 29 +++++++++++++++++++++++++---- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 9b736636def..b45774dde34 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1788,9 +1788,11 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; // src ggml_tensor * src1 = dst->src[1]; // index - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 + || dst->type == GGML_TYPE_BF16); switch (src0->type) { + case GGML_TYPE_BF16: case GGML_TYPE_F16: case GGML_TYPE_F32: if (src0->type == dst->type) { @@ -1881,6 +1883,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { break; } case GGML_TYPE_F16: + case GGML_TYPE_BF16: { acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t)); @@ -1891,7 +1894,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; } acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); + src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type)); aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, dst->type); @@ -1965,7 +1968,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor * // Only check env once. static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); - if (weight_to_nz && is_matmul_weight(weight)) { + if (weight_to_nz && weight->type != GGML_TYPE_BF16 && is_matmul_weight(weight)) { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); } else { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND); @@ -2146,6 +2149,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) { switch (type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif ggml_cann_mat_mul_fp(ctx, dst); break; case GGML_TYPE_Q4_0: diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index a682746bb42..2f9c350789c 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1234,7 +1234,8 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer, static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); if (!need_transform(tensor->type)) { ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); - if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { + if (weight_to_nz && tensor->type != GGML_TYPE_BF16 + && is_matmul_weight((const ggml_tensor *) tensor)) { GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[3] == 1); weight_format_to_nz(tensor, offset, ctx->device); @@ -1443,7 +1444,8 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t if (ne0 % MATRIX_ROW_PADDING != 0) { size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); } - } else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { + } else if (weight_to_nz && tensor->type != GGML_TYPE_BF16 + && is_matmul_weight((const ggml_tensor *) tensor)) { // NZ format weight are not support quantized yet. // If ND tensor transform to NZ, size may changed. int64_t shape[] = { tensor->ne[1], tensor->ne[0] }; @@ -2283,6 +2285,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_MUL_MAT: { switch (op->src[0]->type) { +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif case GGML_TYPE_F16: case GGML_TYPE_F32: return true; @@ -2320,6 +2325,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif case GGML_TYPE_Q8_0: return true; default: @@ -2332,6 +2340,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten switch (op->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif return true; default: return false; @@ -2341,20 +2352,30 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_CPY: { ggml_tensor * src = op->src[0]; +#ifdef ASCEND_310P if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) || (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) { - // only support F32 and F16. + // only support F32 and F16 on 310P. return false; } +#else + if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_BF16) || + (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16 && src->type != GGML_TYPE_BF16)) { + // only support F32, F16 and BF16. + return false; + } +#endif return true; } break; case GGML_OP_CONT: { - // TODO: support GGML_TYPE_BF16 switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif return true; default: return false; From 49b505bcc5c76b30584a26fcc2d1d6751bcc986c Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 20 Mar 2026 06:17:15 -0500 Subject: [PATCH 029/249] vulkan: change gated_delta_net to shard a column across a subgroup (llama/20662) * vulkan: change gated_delta_net to shard a column across a subgroup This is based on https://github.com/ggml-org/llama.cpp/pull/20391, I used an LLM to port the CUDA code to Vulkan, and guided to it to make various fixes to work with Vulkan (e.g. handling different subgroup sizes, unknown mapping of subgroup to invocation id, using subgroupAdd optionally, etc.). This fixes a perf regression from the transposing of the values in memory (!20443). * vulkan: Spread columns across fewer lanes to reduce the number of workgroups --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 38 +++- .../vulkan-shaders/gated_delta_net.comp | 165 +++++++++++------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +- 3 files changed, 140 insertions(+), 67 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3e36435d166..566958b3a9d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4604,12 +4604,42 @@ static void ggml_vk_load_shaders(vk_device& device) { {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"}, {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"}, }; + const bool use_subgroup_reduce = device->subgroup_arithmetic; for (uint32_t si = 0; si < 3; si++) { + const uint32_t S_V = gdn_sizes[si]; + GGML_ASSERT(is_pow2(S_V)); + + uint32_t lanes_per_column; + if (S_V >= 128u && device->subgroup_clustered) { + lanes_per_column = 8u; + } else { + // Use largest power-of-two that divides both S_V and subgroup_size so that + // (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0. + // This means we don't need extra bounds checking logic in the shader. + lanes_per_column = std::min(S_V, device->subgroup_size); + } + + const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size); + size_t gdn_len; + const void * gdn_data; + if (use_subgroup_reduce && need_clustered_shader) { + gdn_len = gated_delta_net_f32_len; + gdn_data = (const void *)gated_delta_net_f32_data; + } else if (use_subgroup_reduce) { + gdn_len = gated_delta_net_f32_nocluster_len; + gdn_data = (const void *)gated_delta_net_f32_nocluster_data; + } else { + gdn_len = gated_delta_net_f32_shmem_len; + gdn_data = (const void *)gated_delta_net_f32_shmem_data; + } + + const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column; + const std::array wg_denoms = {1u, 1u, cols_per_wg}; + for (uint32_t kda = 0; kda < 2; kda++) { ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda], - gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data, - "main", 7, sizeof(vk_op_gated_delta_net_push_constants), - {1, 1, 1}, {gdn_sizes[si], kda}, 1); + gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), + wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size); } } } @@ -10438,7 +10468,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, - pc, { H, n_seqs, 1u }); + pc, { H, n_seqs, S_v }); } static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index f008859b99d..5e9f8308c1d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -1,11 +1,25 @@ #version 450 #extension GL_EXT_control_flow_attributes : require - +#extension GL_KHR_shader_subgroup_basic : enable +#if USE_SUBGROUP_CLUSTERED +#extension GL_KHR_shader_subgroup_clustered : enable +#endif +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif + +// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0, +// so no bounds checking is needed. layout(constant_id = 0) const uint S_V = 128; layout(constant_id = 1) const uint KDA = 0; +layout(constant_id = 2) const uint SUBGROUP_SIZE = 32; +layout(constant_id = 3) const uint LANES_PER_COLUMN = 32; + +const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN; +const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN; -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in; layout(push_constant) uniform Parameters { uint H; @@ -27,14 +41,61 @@ layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; }; layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; }; layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; }; -shared FLOAT_TYPE s_k[S_V]; -shared FLOAT_TYPE s_q[S_V]; -shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i]) +#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED +shared FLOAT_TYPE temp[SUBGROUP_SIZE]; + +// This does a reduction across groups of LANES_PER_COLUMN +FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) { + const uint lane = gl_SubgroupInvocationID; + temp[lane] = partial; + barrier(); + [[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) { + FLOAT_TYPE other = temp[lane ^ s]; + barrier(); + temp[lane] += other; + barrier(); + } + const FLOAT_TYPE result = temp[lane]; + barrier(); + return result; +} +#endif + +// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant +FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) { + switch (LANES_PER_COLUMN) { + case 1u: + return partial; +#if USE_SUBGROUP_CLUSTERED + // Workaround for GLSL requiring a literal constant for the cluster size. + // The branches should all fold away. + case 2u: + return subgroupClusteredAdd(partial, 2u); + case 4u: + return subgroupClusteredAdd(partial, 4u); + case 8u: + return subgroupClusteredAdd(partial, 8u); + case 16u: + return subgroupClusteredAdd(partial, 16u); + case 32u: + return subgroupClusteredAdd(partial, 32u); + case 64u: + return subgroupClusteredAdd(partial, 64u); +#endif + default: +#if USE_SUBGROUP_ADD + return subgroupAdd(partial); +#else + return reduce_add_shmem(partial); +#endif + } +} void main() { const uint head_id = gl_WorkGroupID.x; - const uint seq_id = gl_WorkGroupID.y; - const uint col = gl_LocalInvocationID.x; + const uint seq_id = gl_WorkGroupID.y; + const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN; + const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN); const uint iq1 = head_id % neq1; const uint iq3 = seq_id / rq3; @@ -42,9 +103,9 @@ void main() { const uint state_size = S_V * S_V; const uint state_base = (seq_id * H + head_id) * state_size; - FLOAT_TYPE state[S_V]; - [[unroll]] for (uint i = 0; i < S_V; i++) { - state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]); + FLOAT_TYPE s_shard[ROWS_PER_LANE]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]); } uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; @@ -53,76 +114,56 @@ void main() { const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1; const uint k_off = q_off; const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1; - - s_q[col] = FLOAT_TYPE(data_q[q_off + col]); - s_k[col] = FLOAT_TYPE(data_k[k_off + col]); - const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1; + const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]); - if (KDA != 0) { - const uint g_base = gb_off * S_V; - s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col])); + FLOAT_TYPE k_reg[ROWS_PER_LANE]; + FLOAT_TYPE q_reg[ROWS_PER_LANE]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = FLOAT_TYPE(data_k[k_off + i]); + q_reg[r] = FLOAT_TYPE(data_q[q_off + i]); } - barrier(); - - const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]); - const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]); - + FLOAT_TYPE g_exp[ROWS_PER_LANE]; if (KDA == 0) { const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off])); - - FLOAT_TYPE kv_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - kv_col += dot( - vec4(state[i], state[i+1], state[i+2], state[i+3]), - vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]) - ); - } - - FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val; - - FLOAT_TYPE attn_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); - vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); - sv = g_val * sv + kv * delta_col; - state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w; - - attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3])); + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + g_exp[r] = g_val; } - - data_dst[attn_off + col] = attn_col * scale; } else { - FLOAT_TYPE kv_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]); - vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); - vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); - kv_col += dot(gv * sv, kv); + const uint g_base = gb_off * S_V; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i])); } + } + + const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]); - FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val; + FLOAT_TYPE kv_shard = 0.0; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + kv_shard += g_exp[r] * s_shard[r] * k_reg[r]; + } + FLOAT_TYPE kv_col = reduce_partial(kv_shard); - FLOAT_TYPE attn_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]); - vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); - vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); - sv = gv * sv + kv * delta_col; - state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w; + FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val; - attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3])); - } + FLOAT_TYPE attn_partial = 0.0; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; + } + FLOAT_TYPE attn_col = reduce_partial(attn_partial); + if (lane == 0) { data_dst[attn_off + col] = attn_col * scale; } attn_off += S_V * H; - barrier(); } - [[unroll]] for (uint i = 0; i < S_V; i++) { - data_dst[s_off + state_base + col * S_V + i] = state[i]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index abd2a9c36fa..8186dba36f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -987,7 +987,9 @@ void process_shaders() { string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); - string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}})); + string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}})); + string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}})); + string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); From ca5d565dcdb677bad03f719b7c85bf5939f18e39 Mon Sep 17 00:00:00 2001 From: shalinib-ibm Date: Sat, 21 Mar 2026 04:41:45 +0530 Subject: [PATCH 030/249] ggml-cpu: add always_inline to tinyBLAS_PPC accumulator saves (llama/20791) Explicitly mark save_acc and add_save_Acc with always_inline in tinyBLAS_PPC. This ensures the compiler keeps MMA accumulator disassembly within kernel's register context, preventing un-necessary stask spills. Signed-off-by: Shalini Salomi Bodapati --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index c89e5076f26..63ceb635dea 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -3194,6 +3194,7 @@ class tinyBLAS_PPC { private: + __attribute__((always_inline)) inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { vec_t vec_C[4]; __builtin_mma_disassemble_acc(vec_C, ACC); @@ -3204,6 +3205,7 @@ class tinyBLAS_PPC { } } + __attribute__((always_inline)) inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { vec_t vec_C[4]; __builtin_mma_disassemble_acc(vec_C, ACC); From 22710fdb82e744a521b69214187d9889e244b404 Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Sat, 21 Mar 2026 04:22:51 +0000 Subject: [PATCH 031/249] Add shader count for Intel Arc Pro B60 (llama/20818) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 566958b3a9d..221e6fa04e9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -16048,6 +16048,7 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) case 0xE20C: // B570 return 18; case 0xE20B: // B580 + case 0xE211: // Pro B60 return 20; default: return 0; From 5f3428219a79f5c24604d9d34a3a4a0cbbc1e212 Mon Sep 17 00:00:00 2001 From: y198 <90976397+y198nt@users.noreply.github.com> Date: Sat, 21 Mar 2026 20:59:43 +0700 Subject: [PATCH 032/249] fix(rpc): prevent division by zero in deserialize_tensor (llama/20712) rpc : prevent division by zero in deserialize_tensor When receiving an RPC message with a deprecated tensor type (e.g., type 4 or 5 where `blck_size == 0`), `ggml_row_size()` will trigger a division by zero (SIGFPE) and crash the rpc-server. This patch adds a simple validation check in `deserialize_tensor` to return `nullptr` if the requested tensor type has a block size of 0. (Note: This was originally reported via Security Advisory and maintainer suggested dropping a patch here). * style: remove trailing whitespace --- ggml/src/ggml-rpc/ggml-rpc.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index d7c8ad8c168..5d8defad209 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1162,12 +1162,18 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp return nullptr; } + // Fix: Prevent division by zero if blck_size is 0 (e.g., deprecated types) + if (ggml_blck_size((enum ggml_type)tensor->type) == 0) { + GGML_LOG_ERROR("[%s] invalid tensor type received (blck_size is 0): %u\n", __func__, tensor->type); + return nullptr; + } + ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type if (result == nullptr) { - GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type); + GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type); return nullptr; } From 77b635e9c4f0e3a8fe0252f2197f61b37a62a22c Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sun, 22 Mar 2026 14:19:35 +0530 Subject: [PATCH 033/249] Increase number of output elements per-thread block if the K-dimension is small (llama/20635) * Increase per-thread work if the K-dimension is small With tensor parallelism, the K-dimension of the FFN-down matrices is split, which makes it quite small, especially for MOEs. For example, Qwen3-30b-A3B has a K-dimension of 768, and Qwen3235B-A22B has k-dimension of 1536. The current heuristic uses a group of 4 warps irrespective of K-dimension size, resulting in some of the threads being idle. This results in poor performance for these matrices. This change increases the number of output elements per block for such cases. * Limit this change to ncols_dst = 1 * tab to space --- ggml/src/ggml-cuda/mmvq.cu | 56 +++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 632246e43fd..024b3d8cf22 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) } } -static constexpr __device__ int get_vdr_mmvq(ggml_type type) { +static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; @@ -173,11 +173,11 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d return 1; } -static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) { if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { switch (ncols_dst) { case 1: - return 1; + return small_k ? nwarps : 1; case 2: case 3: case 4: @@ -193,7 +193,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -template +template __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, @@ -208,7 +208,7 @@ static __global__ void mul_mat_vec_q( constexpr int vdr = get_vdr_mmvq(type); constexpr mmvq_parameter_table_id table_id = get_device_table_id(); constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id); - constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); @@ -414,14 +414,16 @@ static __global__ void mul_mat_vec_q( template static std::pair calc_launch_params( const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, - const int warp_size, const mmvq_parameter_table_id table_id) { - const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); + const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) { + const int nwarps = calc_nwarps(type, ncols_dst, table_id); + const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); + const int64_t nblocks = (nrows_x + rpb - 1) / rpb; const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); - const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1); + const dim3 block_dims(warp_size, nwarps, 1); return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, @@ -434,7 +436,7 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -444,7 +446,7 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -488,11 +490,33 @@ static void mul_mat_vec_q_switch_ncols_dst( switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + + // When K is small, increase rows_per_block to match nwarps so each warp has more work to do + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + if (use_small_k) { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, + warp_size, table_id, true); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } else { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, + warp_size, table_id); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } } break; case 2: { constexpr int c_ncols_dst = 2; From 69f0d907ee609091eaa3a552ccfd63c9390e55d6 Mon Sep 17 00:00:00 2001 From: Patrick Buckley Date: Sun, 22 Mar 2026 03:05:51 -0700 Subject: [PATCH 034/249] ggml-cuda: native bf16 flash attention for vec kernel (llama/20525) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-cuda: native bf16 flash attention for vec and tile kernels mma kernel still converts bf16 to fp16 before launch, native mma bf16 todo * ggml-cuda: address code owner review feedback reverted tile kernel changes to avoid larger refactor * fix ci failures on turing and hip * fix bf16 vec kernel compile on hip v_dot2 platforms * add comments --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/CMakeLists.txt | 11 ++--- ggml/src/ggml-cuda/convert.cuh | 6 +++ ggml/src/ggml-cuda/fattn-common.cuh | 48 +++++++++++++++++++ ggml/src/ggml-cuda/fattn-vec.cuh | 26 +++++++--- ggml/src/ggml-cuda/fattn.cu | 16 +++++++ .../fattn-vec-instance-bf16-bf16.cu | 7 +++ .../fattn-vec-instance-bf16-f16.cu | 7 +++ .../fattn-vec-instance-bf16-q4_0.cu | 7 +++ .../fattn-vec-instance-bf16-q4_1.cu | 7 +++ .../fattn-vec-instance-bf16-q5_0.cu | 7 +++ .../fattn-vec-instance-bf16-q5_1.cu | 7 +++ .../fattn-vec-instance-bf16-q8_0.cu | 7 +++ .../fattn-vec-instance-f16-bf16.cu | 7 +++ .../fattn-vec-instance-q4_0-bf16.cu | 7 +++ .../fattn-vec-instance-q4_1-bf16.cu | 7 +++ .../fattn-vec-instance-q5_0-bf16.cu | 7 +++ .../fattn-vec-instance-q5_1-bf16.cu | 7 +++ .../fattn-vec-instance-q8_0-bf16.cu | 7 +++ .../template-instances/generate_cu_files.py | 2 +- ggml/src/ggml-hip/CMakeLists.txt | 11 ++--- ggml/src/ggml-musa/CMakeLists.txt | 11 ++--- 21 files changed, 197 insertions(+), 25 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 262f88204e0..419862101d1 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) + list(APPEND GGML_SOURCES_CUDA + template-instances/fattn-vec-instance-f16-f16.cu + template-instances/fattn-vec-instance-q4_0-q4_0.cu + template-instances/fattn-vec-instance-q8_0-q8_0.cu + template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-cuda diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 09f9a33f909..b8caeacf094 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -41,6 +41,12 @@ template return __bfloat162float(x); } else if constexpr(std::is_same_v && std::is_same_v) { return __float22half2_rn(x); + } else if constexpr(std::is_same_v && std::is_same_v) { +#if !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __bfloat1622float2(x); +#else + return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); +#endif // !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 } else if constexpr(std::is_same_v && std::is_same_v) { // bypass compile error on cuda 12.0.1 #ifdef GGML_USE_HIP diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index e9abdf288c4..c59a4db3999 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( return sum; } +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { + + const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { + __align__(16) nv_bfloat162 tmp[cpy_ne]; + ggml_cuda_memcpy_1(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef V_DOT2_F32_F16_AVAILABLE + // FIXME replace macros in vector FA kernel with templating and use FP32 for BF16 + ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1])); +#else + ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // V_DOT2_F32_F16_AVAILABLE + } + } + + return sum; +} + template static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -321,6 +352,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_ } } +template +static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + static_assert(std::is_same_v, "BF16 V dequantization only supports float output"); + static_assert(ne % 2 == 0, "bad ne"); + __align__(16) nv_bfloat162 tmp[ne/2]; + ggml_cuda_memcpy_1(tmp, (const nv_bfloat16 *) vx + i0); + float2 * dst_f2 = (float2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = ggml_cuda_cast(tmp[l]); + } +} + template static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -547,6 +591,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { return vec_dot_fattn_vec_KQ_q5_1; } else if constexpr (type_K == GGML_TYPE_Q8_0) { return vec_dot_fattn_vec_KQ_q8_0; + } else if constexpr (type_K == GGML_TYPE_BF16) { + return vec_dot_fattn_vec_KQ_bf16; } else { static_assert(type_K == -1, "bad type"); return nullptr; @@ -567,6 +613,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { return dequantize_V_q5_1; } else if constexpr (type_V == GGML_TYPE_Q8_0) { return dequantize_V_q8_0; + } else if constexpr (type_V == GGML_TYPE_BF16) { + return dequantize_V_bf16; } else { static_assert(type_V == -1, "bad type"); return nullptr; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 7cbe32633e5..f0bd42a5761 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec( #endif // GGML_USE_HIP constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); - constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; - constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; + constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q; + constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q; static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); - constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; + constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4; constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); - constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16; #ifdef V_DOT2_F32_F16_AVAILABLE constexpr dequantize_V_t dequantize_V = get_dequantize_V(); #else @@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { half2 tmp[V_rows_per_thread/2]; - dequantize_V(V + k*nb21, tmp, - 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + if constexpr (type_V == GGML_TYPE_BF16) { + float2 tmp_f[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp_f, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { + tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]); + } + } else { + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + } #pragma unroll for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { #pragma unroll @@ -563,6 +573,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) @@ -570,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) @@ -577,6 +589,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) @@ -584,3 +597,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 85c177f496f..a25a890db6d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) @@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) @@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) @@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) @@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #else FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #endif // GGML_CUDA_FA_ALL_QUANTS GGML_ABORT("fatal error"); @@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #endif // GGML_CUDA_FA_ALL_QUANTS case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_BF16: break; default: return BEST_FATTN_KERNEL_NONE; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu new file mode 100644 index 00000000000..3a2fa99b05b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu new file mode 100644 index 00000000000..60f0f6f7952 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu new file mode 100644 index 00000000000..489e05f08c3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu new file mode 100644 index 00000000000..6fa3c26d309 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu new file mode 100644 index 00000000000..421027fb29d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu new file mode 100644 index 00000000000..abbc9434802 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu new file mode 100644 index 00000000000..d641f859d81 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu new file mode 100644 index 00000000000..d1071dc2438 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu new file mode 100644 index 00000000000..8afda314238 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu new file mode 100644 index 00000000000..506864ac18d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu new file mode 100644 index 00000000000..0bbda8371e6 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu new file mode 100644 index 00000000000..79be24daf9e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu new file mode 100644 index 00000000000..45636e5e70c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index e382df1ae20..3b5ab12fc40 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -5,7 +5,7 @@ HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] -TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] +TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index f96c6e09a9b..291b4837455 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -71,12 +71,11 @@ if (GGML_CUDA_FA_ALL_QUANTS) list(APPEND GGML_SOURCES_ROCM ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) + list(APPEND GGML_SOURCES_ROCM + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-hip diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index d76cb51977f..cc53c812ce5 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_SOURCES_MUSA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) + list(APPEND GGML_SOURCES_MUSA + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) From 1d0f0285de5575194a9c42450a1c5293cf433b51 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Sun, 22 Mar 2026 22:06:27 +0800 Subject: [PATCH 035/249] support bf16 and quantized type (llama/20803) --- ggml/src/ggml-sycl/ggml-sycl.cpp | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 2ec1421841b..456b1699fa3 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4667,22 +4667,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g if (a->ne[3] != b->ne[3]) { return false; } - ggml_type a_type = a->type; - if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS || - a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S || - a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S || - a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M - ) { - if (b->ne[1] == 1 && ggml_nrows(b) > 1) { - return false; - } - } + ggml_type src0_type = op->src[0]->type; - if (src0_type == GGML_TYPE_BF16 ) { - // TODO: support GGML_TYPE_BF16 - // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added - return false; - } // TODO: The configuration below needs more work to be supported with oneDNN if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && From 607c92430f5ba00db92f18e1d4097de2212b9d6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 22 Mar 2026 17:53:33 +0100 Subject: [PATCH 036/249] CUDA: fix BF16 FA compilation (llama/20865) --- ggml/src/ggml-cuda/convert.cuh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index b8caeacf094..f5d37c7b998 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -42,11 +42,15 @@ template } else if constexpr(std::is_same_v && std::is_same_v) { return __float22half2_rn(x); } else if constexpr(std::is_same_v && std::is_same_v) { -#if !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#ifdef GGML_USE_HIP + return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); +#else +#if __CUDA_ARCH__ >= 800 return __bfloat1622float2(x); #else - return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); -#endif // !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return make_float2(__bfloat162float(x.x), __bfloat162float(x.y)); +#endif // __CUDA_ARCH__ >= 800 +#endif // GGML_USE_HIP } else if constexpr(std::is_same_v && std::is_same_v) { // bypass compile error on cuda 12.0.1 #ifdef GGML_USE_HIP From c976b22d7bf197ab8a727a99c42d58472ba144b0 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Sun, 22 Mar 2026 22:45:11 -0700 Subject: [PATCH 037/249] opencl: add flattened Q4_K mv and general Q4_K mm (llama/20773) --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 289 ++++++++++++++++++ ggml/src/ggml-opencl/kernels/cvt.cl | 67 ++++ .../kernels/mul_mm_q4_k_f32_l4_lm.cl | 179 +++++++++++ .../kernels/mul_mv_q4_k_f32_flat.cl | 196 ++++++++++++ 5 files changed, 733 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 1f8250934b0..ae667b12d17 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -89,6 +89,7 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_1_f32 mul_mv_q4_1_f32_flat mul_mv_q4_k_f32 + mul_mv_q4_k_f32_flat mul_mv_q6_k_f32 mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 @@ -107,6 +108,7 @@ set(GGML_OPENCL_KERNELS mul_mm_q4_0_f32_l4_lm mul_mm_q4_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_q4_k_f32_l4_lm mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 gemv_noshuffle_q4_1_f32 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index e1dca6b4b4d..c984e59b6b4 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -534,11 +534,13 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q4_1_noshuffle; cl_kernel kernel_restore_block_q4_1_noshuffle; + cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; cl_kernel kernel_mul_mv_q4_1_f32_flat; cl_kernel kernel_mul_mv_q4_K_f32; + cl_kernel kernel_mul_mv_q4_K_f32_flat; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; @@ -578,6 +580,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_q4_0_f32_l4_lm; cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_k_f32_l4_lm; cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; std::vector profiling_info; @@ -917,6 +920,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); GGML_LOG_CONT("."); @@ -1209,6 +1214,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q4_k_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_k_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_k_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1482,6 +1504,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q4_k_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q4_k_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q4_k_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q4_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_q6_k_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3347,6 +3386,40 @@ struct ggml_tensor_extra_cl_q8_0 { } }; +struct ggml_tensor_extra_cl_q4_K { + // Quantized values + cl_mem q = nullptr; + // Scales for each super block. + cl_mem s = nullptr; + // Scales + cl_mem d = nullptr; + // Min + cl_mem dm = nullptr; + + ~ggml_tensor_extra_cl_q4_K() { + reset(); + } + + void reset() { + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (s != nullptr) { + CL_CHECK(clReleaseMemObject(s)); + s = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (dm != nullptr) { + CL_CHECK(clReleaseMemObject(dm)); + dm = nullptr; + } + } +}; + struct ggml_tensor_extra_cl_q6_K { // Lower 4 bits of quantized weights. cl_mem ql = nullptr; @@ -3956,6 +4029,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K) { + delete e; + } + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) { + delete e; + } for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K) { delete e; } @@ -4039,6 +4118,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q4_K * ggml_opencl_alloc_temp_tensor_extra_q4_K() { + ggml_tensor_extra_cl_q4_K * extra; + if (temp_tensor_extras_q4_K.empty()) { + extra = new ggml_tensor_extra_cl_q4_K(); + } else { + extra = temp_tensor_extras_q4_K.back(); + temp_tensor_extras_q4_K.pop_back(); + } + + temp_tensor_extras_q4_K_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() { ggml_tensor_extra_cl_q6_K * extra; if (temp_tensor_extras_q6_K.empty()) { @@ -4080,6 +4174,11 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q8_0_in_use.clear(); + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) { + temp_tensor_extras_q4_K.push_back(e); + } + temp_tensor_extras_q4_K_in_use.clear(); + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { temp_tensor_extras_q6_K.push_back(e); } @@ -4101,6 +4200,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; std::vector temp_tensor_extras_q8_0_in_use; + std::vector temp_tensor_extras_q4_K; + std::vector temp_tensor_extras_q4_K_in_use; std::vector temp_tensor_extras_q6_K; std::vector temp_tensor_extras_q6_K_in_use; @@ -4835,6 +4936,83 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } + if (tensor->type == GGML_TYPE_Q4_K) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q4_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_K(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(3 * ggml_blck_size(tensor->type) / 64); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_d + size_dm + size_s + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Create subbuffer for d. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for mins. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_dm; + extra->dm = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for s. + region.origin = align_to(previous_origin + size_dm, backend_ctx->alignment); + region.size = size_s; + extra->s = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + return; + } if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -5245,6 +5423,34 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + if (tensor->type == GGML_TYPE_Q4_K) { + ggml_tensor_extra_cl_q4_K * extra = (ggml_tensor_extra_cl_q4_K *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; @@ -9357,6 +9563,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; #endif @@ -10005,6 +10212,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q4_K: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q6_K: { if (ne11 < 32) { break; @@ -10449,6 +10700,43 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_q4_K_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = 16; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); +#else kernel = backend_ctx->kernel_mul_mv_q4_K_f32; if (backend_ctx->gpu_family == INTEL) { @@ -10482,6 +10770,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q break; } case GGML_TYPE_Q5_K: diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 78ef9c177f6..272d0ea23f0 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -28,6 +28,7 @@ #define QK8_0 32 #define QR8_0 1 #define QK_K 256 +#define K_SCALE_SIZE (3 * QK_K / 64) #define K_QUANTS_PER_ITERATION 2 typedef char int8_t; @@ -55,6 +56,16 @@ struct block_q4_1 { uchar qs[QK4_1 / 2]; // nibbles / quants }; +//------------------------------------------------------------------------------ +// block_q4_k +//------------------------------------------------------------------------------ +struct block_q4_K { + half d; // delta + half dm; // min + uchar s[K_SCALE_SIZE]; + uchar q[QK_K / 2]; // nibbles / quants +}; + //------------------------------------------------------------------------------ // block_q6_K //------------------------------------------------------------------------------ @@ -408,6 +419,62 @@ kernel void kernel_restore_block_q8_0_trans( } } +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_K +// Convert the block_q4_K format to 4 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +// Each thread processes a super block. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_K( + global struct block_q4_K * src0, + global uchar * dst_q, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm +) { + global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K/2; ++i) { + q[i] = b->q[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +// Restore block_q4_K from flattened arrays. +// Each thread processes a super block. +kernel void kernel_restore_block_q4_K( + global uchar * src_q, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q4_K * dst +) { + global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K/2; ++i) { + b->q[i] = q[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q6_K // Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl new file mode 100644 index 00000000000..2235b1ae838 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl @@ -0,0 +1,179 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 4 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_k_f32_l4_lm( + global uchar4 * src0_q, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 64; + int iqs = (idx % 64) * 2; + + int n = iqs / 32; + int b = (iqs % 32) / 16; + int is = 2 * n + b; + int qsi = n * 32 + (iqs % 16) * 2; + + char * scales = src0_s + ib * 12; + + int scidx0 = (is < 4) ? is : (is + 4); + int scidx1 = (is < 4) ? is : (is - 4); + int scidxmask1 = (is < 4) ? 0x30 : 0xC0; + int scidxshift1 = (is < 4) ? 0 : 2; + int mbidx0 = is + 4; + int mbidx1 = (is < 4) ? is + 4 : is; + int mbidxmask0 = (is < 4) ? 0xF : 0xF0; + int mbidxshift0 = (is < 4) ? 0 : 4; + int mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + int mbidxshift1 = (is < 4) ? 0 : 2; + + uchar sc = (scales[scidx0] & 0xF) | ((scales[scidx1] & scidxmask1) >> scidxshift1); + uchar mbyte = ((scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((scales[mbidx1] & mbidxmask1) >> mbidxshift1); + + float d = (float)src0_d[ib] * (float)sc; + float m = -(float)src0_dm[ib] * (float)mbyte; + + global uchar4 * qs = src0_q + ib*32 + (qsi >> 2); + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 >> (b * 4))&0x0F, (q.s1 >> (b * 4))&0x0F, (q.s2 >> (b * 4))&0x0F, (q.s3 >> (b * 4))&0x0F)))*d + m; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v1.s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl new file mode 100644 index 00000000000..d92fb968904 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl @@ -0,0 +1,196 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q4_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define BLOCK_Q4K_SIZE 144 +#define K_SCALE_SIZE 12 + +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qs[QK_K/2]; // 4-bit quants +} block_q4_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // number of rows each SIMD group works on +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_K_f32_flat( + global uchar * src0_q, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; + int it = get_sub_group_local_id()%8; + int iq = it/4; + int ir = it%4; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = (first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03)/BLOCK_Q4K_SIZE; + uint blk = nb01 / BLOCK_Q4K_SIZE; + global uchar * blk_q = (global uchar *)src0_q + offset_src0*(QK_K/2); + global uchar * blk_s = (global uchar *)src0_s + offset_src0*K_SCALE_SIZE; + global half * blk_d = (global half *)src0_d + offset_src0; + global half * blk_dm = (global half *)src0_dm + offset_src0; + + int offset_src1 = r1*nb11 + (i12)*nb12 + (i13)*nb13; + global float * y = (global float *)(src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * q1 = (global ushort *)(blk_q + ib * (QK_K/2)) + (16 * iq + 4 * ir); + global ushort * sc = (global ushort *)(blk_s + ib * K_SCALE_SIZE) + iq; + global half * d = blk_d + ib; + global half * dm = blk_dm + ib; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F); + acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00); + acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0); + acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000); + acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F); + acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00); + acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0); + acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = *d; + float dmin = *dm; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += blk*64; + sc += blk*6; + d += blk; + dm += blk; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} From a0e41ec26111e76755d2dd8e912e0c88f4582135 Mon Sep 17 00:00:00 2001 From: Dan Hoffman <43101339+thedanhoffman@users.noreply.github.com> Date: Sun, 22 Mar 2026 23:05:37 -0700 Subject: [PATCH 038/249] fix(openvino): explicit memset in buffer_context allocation (llama/20857) * fix(openvino): explicit memset in buffer_context allocation * minor --------- Co-authored-by: Dan Hoffman Co-authored-by: Georgi Gerganov --- ggml/src/ggml-openvino/ggml-openvino.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 0031cb7369f..b3058b4af73 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -97,6 +97,8 @@ struct ggml_backend_openvino_buffer_context { ov_buffer = std::make_shared(std::move(usm_tensor)); } else { data = ggml_aligned_malloc(size); + GGML_ASSERT(data); + memset(data, 0, size); ov_buffer = std::make_shared(ov::element::u8, ov::Shape{size}, data); } From 54f5c02f29a6d5f20f94d4707dded5fcc0b2fdb0 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Mon, 23 Mar 2026 15:24:06 +0800 Subject: [PATCH 039/249] CANN: add RoPE cache preload before ACL graph capture (llama/20747) ACL graph capture disallows host-to-device memcpy and device memory malloc/free on the captured stream. Pre-load the RoPE cache before capture so that: - Host-to-device copies and allocations run on the non-captured stream - Cache metadata is populated and memory pool is warmed up - During capture, only on-device computations are recorded; host-side and allocation branches are skipped --- ggml/src/ggml-cann/aclnn_ops.cpp | 52 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cann/aclnn_ops.h | 15 +++++++++ ggml/src/ggml-cann/common.h | 2 +- ggml/src/ggml-cann/ggml-cann.cpp | 13 ++++++++ 4 files changed, 81 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index b45774dde34..adb4d68e868 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -3011,6 +3011,58 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { } } +void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + GGML_TENSOR_UNARY_OP_LOCALS + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int) * 4); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; + const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_imrope || mrope_used) { + is_neox = true; + } + + int64_t rope_dims = n_dims; + if (is_vision) { + rope_dims = src0->ne[0]; + } + + // Run the full cache init on the non-captured stream. This performs all + // host-to-device memcpy, aclrtMalloc/Free, and on-device computations + // so that the memory pool is warmed up and cache metadata is populated. + aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, + mrope_used, is_imrope, is_vision, rope_dims); + + // Reset `cached` so that during graph capture the on-device computations + // (sin/cos, position multiply, repeat, etc.) still execute and get recorded + // into the captured graph. The cache metadata (theta_scale_length, + // theta_scale, sections, position_length, etc.) remains set, which causes + // all host-to-device copy and malloc/free branches to be skipped. + ctx.rope_cache.cached = false; +} + void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 3effa1c289c..7f5ba4d3302 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -543,6 +543,21 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst); */ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst); +/** + * @brief Pre-load the RoPE cache before ACL graph capture. + * + * This function must be called outside of graph capture to perform + * host-to-device memory copies and device memory allocations that are + * not allowed on a captured stream. After pre-loading, the rope cache + * metadata is updated so that the subsequent call to + * aclnn_rope_cache_init (inside graph capture) skips these operations + * and only records the on-device computations into the captured graph. + * + * @param ctx CANN backend context. + * @param dst A ROPE destination tensor from the computation graph. + */ +void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Computes the index of the maximum value along the specified dimension * of a ggml tensor using the CANN backend. diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 0120f0dfd1e..5f960548cd2 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -277,7 +277,7 @@ struct ggml_graph_node_properties { } } - if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) { + if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU || node->op == GGML_OP_ROPE){ return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; } return true; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 2f9c350789c..6f26e91e046 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2225,6 +2225,19 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, // If no matching graph is found, add a new ACL graph. ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph); cann_ctx->graph_lru_cache.push(new_graph); + + // Pre-load rope cache before graph capture. During capture the + // stream cannot perform host-to-device memcpy or device memory + // malloc/free. Running the full cache init now populates the + // cache metadata so these branches are skipped during capture, + // while also warming up the memory pool. + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->op == GGML_OP_ROPE) { + ggml_cann_rope_cache_preload(*cann_ctx, node); + break; + } + } } } #else From c589dd77d4fe9df5ea6f1072d7913e870ed4da10 Mon Sep 17 00:00:00 2001 From: Rashid Ul Islam <33536561+Ra5hidIslam@users.noreply.github.com> Date: Mon, 23 Mar 2026 13:15:34 +0530 Subject: [PATCH 040/249] metal: add CONV_3D (llama/19927) * Apply suggestions from code review Co-authored-by: Georgi Gerganov * metal:add conv_3d backend Rebased with master and resolved conflicts. * Resolved issues related to changes in variable names * kernel void kernel_upscale_bilinear_f32 was missing in my branch, added back, should pass all tests now --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal-device.cpp | 22 ++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 5 ++ ggml/src/ggml-metal/ggml-metal-impl.h | 36 +++++++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 75 ++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 92 +++++++++++++++++++++++ 7 files changed, 232 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 72ad876d5e4..9162342ee98 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1748,6 +1748,28 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_met return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_CONV_3D); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_conv_3d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_UPSCALE); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index fd2b3ddeb55..de43f819312 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -148,6 +148,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 82101f4714e..14144aab087 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1077,6 +1077,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_CONV_3D: + return ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && + op->src[1]->type == GGML_TYPE_F32; case GGML_OP_SUM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_TRI: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 53437b23cda..ea471090cd8 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -643,6 +643,42 @@ typedef struct { int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources } ggml_metal_kargs_im2col; +typedef struct { + int32_t IW; + int32_t IH; + int32_t ID; + int32_t OW; + int32_t OH; + int32_t OD; + int32_t KW; + int32_t KH; + int32_t KD; + int32_t s0; + int32_t s1; + int32_t s2; + int32_t p0; + int32_t p1; + int32_t p2; + int32_t d0; + int32_t d1; + int32_t d2; + int32_t IC; + int32_t N; + int32_t OC; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_conv_3d; + typedef struct{ int32_t ne00; uint64_t nb01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index c0bcad392b9..3cda21be43e 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -394,6 +394,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx); } break; + case GGML_OP_CONV_3D: + { + n_fuse = ggml_metal_op_conv_3d(ctx, idx); + } break; case GGML_OP_UPSCALE: { n_fuse = ggml_metal_op_upscale(ctx, idx); @@ -3697,6 +3701,77 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_conv_3d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + // 1. Extract standard dimensions and byte strides + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + // 2. Extract hyperparams from op_params + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + const int32_t s1 = ((const int32_t *)(op->op_params))[1]; + const int32_t s2 = ((const int32_t *)(op->op_params))[2]; + const int32_t p0 = ((const int32_t *)(op->op_params))[3]; + const int32_t p1 = ((const int32_t *)(op->op_params))[4]; + const int32_t p2 = ((const int32_t *)(op->op_params))[5]; + const int32_t d0 = ((const int32_t *)(op->op_params))[6]; + const int32_t d1 = ((const int32_t *)(op->op_params))[7]; + const int32_t d2 = ((const int32_t *)(op->op_params))[8]; + const int32_t IC = ((const int32_t *)(op->op_params))[9]; + const int32_t N = ((const int32_t *)(op->op_params))[10]; + const int32_t OC = ((const int32_t *)(op->op_params))[11]; + + // 3. Build the parameter struct using the macro-generated variables + ggml_metal_kargs_conv_3d args = { + /*.IW =*/ (int32_t)op->src[1]->ne[0], + /*.IH =*/ (int32_t)op->src[1]->ne[1], + /*.ID =*/ (int32_t)op->src[1]->ne[2], + /*.OW =*/ (int32_t)op->ne[0], + /*.OH =*/ (int32_t)op->ne[1], + /*.OD =*/ (int32_t)op->ne[2], + /*.KW =*/ (int32_t)op->src[0]->ne[0], + /*.KH =*/ (int32_t)op->src[0]->ne[1], + /*.KD =*/ (int32_t)op->src[0]->ne[2], + s0, s1, s2, + p0, p1, p2, + d0, d1, d2, + IC, N, OC, + nb00, nb01, nb02, nb03, // Weight strides + nb10, nb11, nb12, nb13, // Input strides + nb0, nb1, nb2, nb3 // Output strides + }; + + // 4. Fetch the JIT pipeline + auto pipeline = ggml_metal_library_get_pipeline_conv_3d(lib, op); + + // 5. Grid mapping + int nth0 = 32; // Standard SIMD width for Apple Silicon + int nth1 = 1; + int nth2 = 1; + + int64_t spatial_volume = args.OW * args.OH * args.OD; + + int ntg0 = (spatial_volume + nth0 - 1) / nth0; + int ntg1 = args.OC; + int ntg2 = args.N; + + // 6. Bind and Dispatch via the ggml C wrapper + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, nth0, nth1, nth2); + + return 1; +} + int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 019f2fec9ed..50e3c5c77a1 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -75,6 +75,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_conv_3d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index b2328605dd9..9c6b1c4f62b 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4883,6 +4883,98 @@ kernel void kernel_upscale_bilinear_f32( } } +template +kernel void kernel_conv_3d( + constant ggml_metal_kargs_conv_3d & args, + device const char * src0, // Weights [IC * OC, KD, KH, KW] + device const char * src1, // Inputs [IC * N, ID, IH, IW] + device char * dst, // Outputs [OC * N, OD, OH, OW] + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + + // 1. Un-flatten the spatial dimension from Grid X + int64_t spatial_idx = tgpig.x * 32 + tpitg.x; + + if (spatial_idx >= args.OW * args.OH * args.OD) { + return; // Thread falls outside the spatial volume + } + + int64_t od = spatial_idx / (args.OW * args.OH); + int64_t oh = (spatial_idx / args.OW) % args.OH; + int64_t ow = spatial_idx % args.OW; + + // 2. Map Y to Channels, Z to Batch + int64_t oc = tgpig.y; + int64_t batch_idx = tgpig.z; + + // 3. Calculate anchor coordinates in the Input volume + int64_t i_w_base = ow * args.s0 - args.p0; + int64_t i_h_base = oh * args.s1 - args.p1; + int64_t i_d_base = od * args.s2 - args.p2; + + float sum = 0.0f; + + // 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width) + for (int64_t ic = 0; ic < args.IC; ++ic) { + + // ggml packs batch and channel together in the 4th dimension + int64_t src_cn_idx = batch_idx * args.IC + ic; + int64_t w_cn_idx = oc * args.IC + ic; + + for (int64_t kz = 0; kz < args.KD; ++kz) { + int64_t id = i_d_base + kz * args.d2; + if (id < 0 || id >= args.ID) continue; // Boundary check (Padding) + + for (int64_t ky = 0; ky < args.KH; ++ky) { + int64_t ih = i_h_base + ky * args.d1; + if (ih < 0 || ih >= args.IH) continue; + + for (int64_t kx = 0; kx < args.KW; ++kx) { + int64_t iw = i_w_base + kx * args.d0; + if (iw < 0 || iw >= args.IW) continue; + + // Convert multi-dimensional coordinates to flat byte offsets + int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03; + int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13; + + // Dereference memory and cast weights to f32 if they were f16 + float w_val = (float)*(device const T*)((device const char*)src0 + w_idx); + float i_val = *(device const float*)((device const char*)src1 + i_idx); + + sum += w_val * i_val; + } + } + } + } + + // 5. Write the accumulated value out to RAM + int64_t dst_cn_idx = batch_idx * args.OC + oc; + int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3; + + *(device float*)(dst + d_idx) = sum; +} + +// Explicit instantiations so the JIT compiler can find them by name +template [[host_name("kernel_conv_3d_f32_f32")]] +kernel void kernel_conv_3d( + constant ggml_metal_kargs_conv_3d & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +// Explicit instantiation for f16 weights +template [[host_name("kernel_conv_3d_f16_f32")]] +kernel void kernel_conv_3d( + constant ggml_metal_kargs_conv_3d & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + + static inline float bicubic_weight1(float x) { const float a = -0.75f; return ((a + 2) * x - (a + 3)) * x * x + 1; From 37c0a52c1bf4cd2bd32eca4cfce26fa667ae5736 Mon Sep 17 00:00:00 2001 From: las7 <98077186+las7@users.noreply.github.com> Date: Mon, 23 Mar 2026 10:54:57 -0700 Subject: [PATCH 041/249] rpc : RCE patch (llama/20908) --- ggml/src/ggml-rpc/ggml-rpc.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 5d8defad209..0ed2c0dce60 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1443,7 +1443,9 @@ ggml_tensor * rpc_server::create_node(uint64_t id, const rpc_tensor * tensor = it_ptr->second; struct ggml_tensor * result = deserialize_tensor(ctx, tensor); - if (result == nullptr) { + if (result == nullptr || result->buffer == nullptr) { + GGML_LOG_ERROR("[%s] invalid tensor: null %s (id=%" PRIu64 ")\n", + __func__, result == nullptr ? "tensor" : "buffer", id); return nullptr; } tensor_map[id] = result; From 624be93425126001fdcaf830be9e0b719705c4b9 Mon Sep 17 00:00:00 2001 From: lhez Date: Mon, 23 Mar 2026 12:44:18 -0700 Subject: [PATCH 042/249] opencl: add q6_K gemm and gemv kernels for Adreno (llama/20089) * opencl: add q6_K noshuffle kernels, initial q6_K gemv, some host code * opencl: add q6_K transpose * opencl: fix cvt kernel name * opencl: add call to q6_K gemv * opencl: fix q6_K scale transpose * opencl: fix loading for gemv q6_K, refactor * opencl: fix transpose_8_buf kernel assignment, refactor * opencl: refactor q6_K transpose * opencl: add gemm_noshuffle_q6_k_f32 * opencl: fix qh loading * opencl: refactor q6_K gemv host side, release bufs and imgs * opencl: refactor * opencl: fix q6_K dequant and scale selection * opencl: workaround compiler bug, fix dump_tensor * opencl: refactor q6_K convert kernels * opencl: unpack transformed q6_K in get_tensor * opencl: refactor, handle non-uniform workgroups * opencl: support non-vector subgroup bcast --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 397 ++++++++++++++++-- ggml/src/ggml-opencl/kernels/cvt.cl | 128 +++++- .../kernels/gemm_noshuffle_q6_k_f32.cl | 140 ++++++ .../kernels/gemv_noshuffle_q6_k_f32.cl | 293 +++++++++++++ 5 files changed, 920 insertions(+), 40 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index ae667b12d17..af29f3b8f4c 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -114,6 +114,8 @@ set(GGML_OPENCL_KERNELS gemv_noshuffle_q4_1_f32 gemm_noshuffle_q4_1_f32 gemv_noshuffle_general_q8_0_f32 + gemv_noshuffle_q6_k_f32 + gemm_noshuffle_q6_k_f32 mul neg norm diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index c984e59b6b4..4dddcd82cfa 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -529,6 +529,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; + cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_restore_block_q4_0_noshuffle; @@ -716,6 +717,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemm_noshuffle_q4_1_f32; cl_kernel kernel_mul_mm_q8_0_f32_8x4; cl_kernel CL_mul_mat_vec_q8_0_f32; + cl_kernel kernel_gemv_noshuffle_q6_K_f32; + cl_kernel kernel_gemm_noshuffle_q6_K_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { @@ -924,6 +927,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K_noshuffle", &err), err)); GGML_LOG_CONT("."); } @@ -2642,6 +2647,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err)); GGML_LOG_CONT("."); } + + // gemv_noshuffle_q6_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q6_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q6_k_f32.cl"); +#endif + + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // gemm_noshuffle_q6_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q6_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q6_k_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } #endif // GGML_OPENCL_USE_ADRENO_KERNELS GGML_LOG_CONT("\n"); } @@ -5029,61 +5073,58 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, "Incorrect tensor size"); cl_int err; - cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, - ggml_nbytes(tensor), NULL, &err); - CL_CHECK(err); - CL_CHECK(clEnqueueWriteBuffer( - queue, data_device, CL_TRUE, 0, - ggml_nbytes(tensor), data, 0, NULL, NULL)); + cl_mem data_device; + CL_CHECK((data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err), err)); + CL_CHECK(clEnqueueWriteBuffer(queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL)); cl_buffer_region region; // Subbuffer for ql region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); region.size = size_ql; - extra->ql = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); + CL_CHECK((extra->ql = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); auto previous_origin = region.origin; // Subbuffer for qh region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment); region.size = size_qh; - extra->qh = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); previous_origin = region.origin; // Subbuffer for scales region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); region.size = size_s; - extra->s = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); + CL_CHECK((extra->s = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); previous_origin = region.origin; // Create subbuffer for d. region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); region.size = size_d; - extra->d = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); + CL_CHECK((extra->d = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); previous_origin = region.origin; // Flatten the weights - cl_kernel kernel = backend_ctx->kernel_convert_block_q6_K; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + cl_kernel kernel; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + kernel = backend_ctx->kernel_convert_block_q6_K; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q6_K_noshuffle; + } +#else + kernel = backend_ctx->kernel_convert_block_q6_K; +#endif // GGML_OPENCL_USE_ADRENO_KERNELS - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + cl_uchar mask = 0xff; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64)*64, 1, 1}; size_t local_work_size[] = {64, 1, 1}; cl_event evt; @@ -5097,6 +5138,29 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, extra->size_d = size_d; tensor->extra = extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + cl_int M = tensor->ne[1]; // ne01 + cl_int K = tensor->ne[0]; // ne00 + + // Transpose ql as ushort + transpose_2d_as_16b(backend_ctx, + extra->ql, extra->ql, size_ql, K/4, M); + + // Transpose qh as uchar + transpose_2d_as_8b(backend_ctx, + extra->qh, extra->qh, size_qh, K/4, M); + + // Transpose s as ushort + transpose_2d_as_16b(backend_ctx, + extra->s, extra->s, size_s, K/16/2, M); + + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, + extra->d, extra->d, size_d, K/256, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } #endif // GGML_OPENCL_SOA_Q @@ -5454,19 +5518,78 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + static ggml_cl_buffer buf_trans_ql; + static ggml_cl_buffer buf_trans_qh; + static ggml_cl_buffer buf_trans_s; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; // ne01 + cl_int K = tensor->ne[0]; // ne00 + + GGML_ASSERT(K % ggml_blck_size(tensor->type) == 0); + + size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4; + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) && "Incorrect tensor size"); + + buf_trans_ql.allocate(backend_ctx->context, size_ql); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_s.allocate(backend_ctx->context, size_s); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // transpose ql, qh, s and d back + transpose_2d_as_16b(backend_ctx, extra->ql, buf_trans_ql.buffer, size_ql, M, K/4); + transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->s, buf_trans_s.buffer, size_s, M, K/16/2); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256); + + // unpack + cl_uchar mask = 0xFF; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_ql.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_s.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err); CL_CHECK(err); + cl_uchar mask = 0xFF; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); - - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; size_t local_work_size[] = {1, 1, 1}; cl_event evt; @@ -5759,6 +5882,8 @@ typedef struct { static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); +#define QK_MXFP4 32 + #include #ifdef __cplusplus #include "half.hpp" @@ -5802,7 +5927,7 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso buf_d = malloc(size_e); CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); - CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, extra->e, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL)); CL_CHECK(clFinish(queue)); } else { // Read out the tensor from GPU memory. @@ -9537,6 +9662,196 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t #endif } +static void ggml_cl_mul_mat_q6_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_buffer_region region; + cl_image_format img_fmt; + cl_image_desc img_desc; + + // subbuffer and image for activation + if (ne1 == 1) { + cl_mem ql_img = nullptr; + cl_mem qh_img = nullptr; + cl_mem b_sub_buffer = nullptr; + cl_mem b_img = nullptr; + + // image for ql + img_fmt.image_channel_order = CL_R; + img_fmt.image_channel_data_type = CL_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne01 * ne00 / 8; + img_desc.buffer = extra0_q6_K->ql; + CL_CHECK((ql_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // image for qh + img_fmt.image_channel_order = CL_R; + img_fmt.image_channel_data_type = CL_HALF_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne01 * ne00 / 8; + img_desc.buffer = extra0_q6_K->qh; + CL_CHECK((qh_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + region.origin = offset1; + region.size = ne00 * ne1 * sizeof(float); + CL_CHECK((b_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + img_fmt.image_channel_order = CL_RGBA; + img_fmt.image_channel_data_type = CL_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne00 * ne1 / 4; + img_desc.buffer = b_sub_buffer; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q6_K_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &ql_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qh_img)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(ql_img)); + CL_CHECK(clReleaseMemObject(qh_img)); + CL_CHECK(clReleaseMemObject(b_sub_buffer)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf; + cl_mem b_buf_trans; + cl_mem b_img; + cl_mem b_img_trans; + + // subbuffer for activation + region.origin = offset1; + region.size = ne00 * ne1 * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activation + img_fmt.image_channel_order = CL_RGBA; + img_fmt.image_channel_data_type = CL_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne00 * ne1 / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = ne1 % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activation + region.origin = 0; + region.size = ne00 * (ne1 + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activation + img_fmt.image_channel_order = CL_RGBA; + img_fmt.image_channel_data_type = CL_HALF_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne00 * (ne1 + padding) / 4; + img_desc.buffer = b_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activation + int height_B = ne1/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = ne00/4; + int padded_height_B = (ne1 + padding) / 4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_size_t[2] = { 1, 16 }; + size_t global_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q6_K_f32; + int padded_N = ne1 + padding; + + cl_ushort mask_f000 = 0xF000; + cl_uchar mask_c0 = 0xC0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ushort),&mask_f000)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_c0)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {2, 128, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -9673,6 +9988,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // q6_K x fp32 + if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); + return; + } + // q4_0 x fp32 if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { // TODO: remove duplicate definitions of image description + format -- move to top diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 272d0ea23f0..34930dfbe6a 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -486,8 +486,13 @@ kernel void kernel_convert_block_q6_K( global uchar * dst_ql, global uchar * dst_qh, global char * dst_s, - global half * dst_d + global half * dst_d, + uchar mask_lsb_8, + ulong n_blk ) { + if (get_global_id(0) >= n_blk) { + return; + } global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0); global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); @@ -514,8 +519,13 @@ kernel void kernel_restore_block_q6_K( global uchar * dst_qh, global char * dst_s, global half * dst_d, - global struct block_q6_K * dst + global struct block_q6_K * dst, + uchar mask_lsb_8, + ulong n_blk ) { + if (get_global_id(0) >= n_blk) { + return; + } global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0); global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); @@ -534,3 +544,117 @@ kernel void kernel_restore_block_q6_K( b->scales[i] = s[i]; } } + +kernel void kernel_convert_block_q6_K_noshuffle( + global struct block_q6_K * src0, + global uchar * dst_ql, + global uchar * dst_qh, + global char * dst_s, + global half * dst_d, + uchar mask_lsb_8, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0); + global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) dst_s + QK_K/16*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK_K/2/4; ++i) { + uchar x0 = b->ql[i*2 + 0] & mask_lsb_8; + uchar x1 = b->ql[i*2 + 1] & mask_lsb_8; + ql[i + 0] = (x0 & 0x0F) | ((x1 & 0x0F) << 4); + ql[i + 32] = ((x0 & 0xF0) >> 4) | (x1 & 0xF0); + + uchar x2 = b->ql[i*2 + 0 + 64] & mask_lsb_8; + uchar x3 = b->ql[i*2 + 1 + 64] & mask_lsb_8; + ql[i + 64] = (x2 & 0x0F) | ((x3 & 0x0F) << 4); + ql[i + 96] = ((x2 & 0xF0) >> 4) | (x3 & 0xF0); + } + + for (int i = 0; i < QK_K/4/8; ++i) { + uchar x0 = b->qh[i*4 + 0] & mask_lsb_8; + uchar x1 = b->qh[i*4 + 1] & mask_lsb_8; + uchar x2 = b->qh[i*4 + 2] & mask_lsb_8; + uchar x3 = b->qh[i*4 + 3] & mask_lsb_8; + qh[i + 0] = (x0 & 0x03) | ((x1 & 0x03) << 2) | ((x2 & 0x03) << 4) | ((x3 & 0x03) << 6); + qh[i + 8] = ((x0 & 0x0C) >> 2) | (x1 & 0x0C) | ((x2 & 0x0C) << 2) | ((x3 & 0x0C) << 4); + qh[i + 16] = ((x0 & 0x30) >> 4) | ((x1 & 0x30) >> 2) | (x2 & 0x30) | ((x3 & 0x30) << 2); + qh[i + 24] = ((x0 & 0xC0) >> 6) | ((x1 & 0xC0) >> 4) | ((x2 & 0xC0) >> 2) | (x3 & 0xC0); + + uchar x4 = b->qh[i*4 + 0 + 32] & mask_lsb_8; + uchar x5 = b->qh[i*4 + 1 + 32] & mask_lsb_8; + uchar x6 = b->qh[i*4 + 2 + 32] & mask_lsb_8; + uchar x7 = b->qh[i*4 + 3 + 32] & mask_lsb_8; + qh[i + 32] = (x4 & 0x03) | ((x5 & 0x03) << 2) | ((x6 & 0x03) << 4) | ((x7 & 0x03) << 6); + qh[i + 40] = ((x4 & 0x0C) >> 2) | (x5 & 0x0C) | ((x6 & 0x0C) << 2) | ((x7 & 0x0C) << 4); + qh[i + 48] = ((x4 & 0x30) >> 4) | ((x5 & 0x30) >> 2) | (x6 & 0x30) | ((x7 & 0x30) << 2); + qh[i + 56] = ((x4 & 0xC0) >> 6) | ((x5 & 0xC0) >> 4) | ((x6 & 0xC0) >> 2) | (x7 & 0xC0); + } + + for (int i = 0; i < QK_K/16; ++i) { + s[i] = b->scales[i]; + } +} + +kernel void kernel_restore_block_q6_K_noshuffle( + global uchar * src_ql, + global uchar * src_qh, + global char * src_s, + global half * src_d, + global struct block_q6_K * dst, + uchar mask_lsb_8, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0); + global uchar * ql = (global uchar *) src_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) src_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) src_s + QK_K/16*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + + for (int i = 0; i < QK_K/2/4; ++i) { + uchar x0 = ql[i + 0] & mask_lsb_8; + uchar x1 = ql[i + 32] & mask_lsb_8; + b->ql[i*2 + 0] = (x0 & 0x0F) | ((x1 & 0x0F) << 4); + b->ql[i*2 + 1] = ((x0 & 0xF0) >> 4) | (x1 & 0xF0); + + uchar x2 = ql[i + 64] & mask_lsb_8; + uchar x3 = ql[i + 96] & mask_lsb_8; + b->ql[i*2 + 0 + 64] = (x2 & 0x0F) | ((x3 & 0x0F) << 4); + b->ql[i*2 + 1 + 64] = ((x2 & 0xF0) >> 4) | (x3 & 0xF0); + } + + for (int i = 0; i < QK_K/4/8; ++i) { + uchar x0 = qh[i + 0] & mask_lsb_8; + uchar x1 = qh[i + 8] & mask_lsb_8; + uchar x2 = qh[i + 16] & mask_lsb_8; + uchar x3 = qh[i + 24] & mask_lsb_8; + b->qh[i*4 + 0] = (x0 & 0x03) | ((x1 & 0x03) << 2) | ((x2 & 0x03) << 4) | ((x3 & 0x03) << 6); + b->qh[i*4 + 1] = ((x0 & 0x0C) >> 2) | (x1 & 0x0C) | ((x2 & 0x0C) << 2) | ((x3 & 0x0C) << 4); + b->qh[i*4 + 2] = ((x0 & 0x30) >> 4) | ((x1 & 0x30) >> 2) | (x2 & 0x30) | ((x3 & 0x30) << 2); + b->qh[i*4 + 3] = ((x0 & 0xC0) >> 6) | ((x1 & 0xC0) >> 4) | ((x2 & 0xC0) >> 2) | (x3 & 0xC0); + + uchar x4 = qh[i + 0 + 32] & mask_lsb_8; + uchar x5 = qh[i + 8 + 32] & mask_lsb_8; + uchar x6 = qh[i + 16 + 32] & mask_lsb_8; + uchar x7 = qh[i + 24 + 32] & mask_lsb_8; + b->qh[i*4 + 0 + 32] = (x4 & 0x03) | ((x5 & 0x03) << 2) | ((x6 & 0x03) << 4) | ((x7 & 0x03) << 6); + b->qh[i*4 + 1 + 32] = ((x4 & 0x0C) >> 2) | (x5 & 0x0C) | ((x6 & 0x0C) << 2) | ((x7 & 0x0C) << 4); + b->qh[i*4 + 2 + 32] = ((x4 & 0x30) >> 4) | ((x5 & 0x30) >> 2) | (x6 & 0x30) | ((x7 & 0x30) << 2); + b->qh[i*4 + 3 + 32] = ((x4 & 0xC0) >> 6) | ((x5 & 0xC0) >> 4) | ((x6 & 0xC0) >> 2) | (x7 & 0xC0); + } + + for (int i = 0; i < QK_K/16; ++i) { + b->scales[i] = s[i]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl new file mode 100644 index 00000000000..3a9c624508a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl @@ -0,0 +1,140 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif +kernel void kernel_gemm_noshuffle_q6_K_f32( + global const ushort * src0_ql, + global const uchar * src0_qh, + global const ushort * src0_s, + global const half * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding, + ushort mask_f000, + uchar mask_c0 +) { + dst = (global float *)( (global char *)dst + offsetd ); + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); // n + int gx = get_global_id(1); // m + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * ptr_ql = src0_ql + gx_2; + global const uchar * ptr_qh = src0_qh + gx_2; + global const ushort * ptr_s = src0_s + gx_2; + global const half * ptr_d = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + // load 4x elements (ushort) of ql on M, each ushort contains 4 weights + // 4x ushort correspons to 4 rows on M + ushort4 bits4 = vload4(0, ptr_ql + (i/4)*m); // ql packed in 4s in ushort + uchar4 bits2 = vload4(0, ptr_qh + (i/4)*m); // qh packed in 4s in uchar + + // load 4 consecutive scales + char8 scale_s_8 = as_char8(vload4(0, ptr_s + (i/16/2)*m)); // 1 char scale every 16 elements, packed in 2s + char4 scale_s = ((i/16) % 2) == 0 ? scale_s_8.s0246 : scale_s_8.s1357; // transposed as ushort, 2 blocks + half4 scale_d = vload4(0, ptr_d + (i/256)*m); // 1 half scale every 256 elements + + // j=0 + // load 2x 4 elements of activations on N, corresponding to 8 rows on N + B.s0123 = read_imageh(src1, gy*2 + (i + 0)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 0)*n_4 + 1); + dequantized_weights.s0 = (convert_half((bits4.s0 & 0x000F) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((bits4.s1 & 0x000F) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((bits4.s2 & 0x000F) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((bits4.s3 & 0x000F) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i + 1)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 1)*n_4 + 1); + dequantized_weights.s0 = (convert_half((((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2))) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2))) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2))) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2))) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i + 2)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 2)*n_4 + 1); + dequantized_weights.s0 = (convert_half((((bits4.s0 & 0x0F00) >> 8) | (bits2.s0 & 0x30))) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((((bits4.s1 & 0x0F00) >> 8) | (bits2.s1 & 0x30))) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((((bits4.s2 & 0x0F00) >> 8) | (bits2.s2 & 0x30))) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((((bits4.s3 & 0x0F00) >> 8) | (bits2.s3 & 0x30))) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i + 3)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 3)*n_4 + 1); + dequantized_weights.s0 = (convert_half((((bits4.s0 & mask_f000) >> 12) | ((bits2.s0 & mask_c0) >> 2))) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((((bits4.s1 & mask_f000) >> 12) | ((bits2.s1 & mask_c0) >> 2))) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((((bits4.s2 & mask_f000) >> 12) | ((bits2.s2 & mask_c0) >> 2))) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((((bits4.s3 & mask_f000) >> 12) | ((bits2.s3 & mask_c0) >> 2))) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl new file mode 100644 index 00000000000..6f89cf968b9 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl @@ -0,0 +1,293 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantize_block_acc_bcast_8_hi(total_sum, bits4, bits2, scale_d, scale_s, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s7; \ + +#define dequantize_block_acc_bcast_8_lo(total_sum, bits4, bits2, scale_d, scale_s, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s7; \ + +#define dequantize_block_acc_bcast_1_hi(total_sum, bits4, bits2, scale_d, scale_s, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + +#define dequantize_block_acc_bcast_1_lo(total_sum, bits4, bits2, scale_d, scale_s, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + +#if defined(ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q6_K_f32( + read_only image1d_buffer_t src0_ql, + read_only image1d_buffer_t src0_qh, + global half2 * src0_s, + global half2 * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01 +) { + int grp = get_local_id(1); + int gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + int nb = ne00 / 32; + + uint4 reg_a_l; + ushort4 reg_a_h; + half2 reg_d; + char4 reg_s; + float8 reg_b; + + float2 total_sum = 0.0f; + + int line_stride_a = ne01 / 2; + int block_stride_a = NSUBGROUPS * ne01; + + for (int k = grp; k < nb; k += NSUBGROUPS) { + reg_d = src0_d[gid + k/8 * line_stride_a]; + reg_s = as_char4(src0_s[gid + k * line_stride_a]); + + if (slid < 4) { + reg_b.s0123 = read_imagef(src1, 0 + slid*2 + k*8); + reg_b.s4567 = read_imagef(src1, 1 + slid*2 + k*8); + } + + reg_a_l.s0 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*0).x; + reg_a_l.s1 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*1).x; + reg_a_l.s2 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*2).x; + reg_a_l.s3 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*3).x; + + reg_a_h.s0 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*0).x); + reg_a_h.s1 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*1).x); + reg_a_h.s2 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*2).x); + reg_a_h.s3 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*3).x); + +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantize_block_acc_bcast_8_hi(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#else + dequantize_block_acc_bcast_1_hi(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#endif // VECTOR_SUB_GROUP_BROADCAT + + reg_a_l.s0 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*4).x; + reg_a_l.s1 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*5).x; + reg_a_l.s2 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*6).x; + reg_a_l.s3 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*7).x; + + reg_a_h.s0 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*4).x); + reg_a_h.s1 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*5).x); + reg_a_h.s2 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*6).x); + reg_a_h.s3 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*7).x); + +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantize_block_acc_bcast_8_lo(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#else + dequantize_block_acc_bcast_1_lo(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + local float2 reduce_lm[SUBGROUP_SIZE * 3]; + if (grp == 1) { + reduce_lm[SUBGROUP_SIZE*0 + slid] = total_sum; + } + if (grp == 2) { + reduce_lm[SUBGROUP_SIZE*1 + slid] = total_sum; + } + if (grp == 3) { + reduce_lm[SUBGROUP_SIZE*2 + slid] = total_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (grp == 0) { + total_sum += reduce_lm[SUBGROUP_SIZE*0 + slid]; + } + if (grp == 0) { + total_sum += reduce_lm[SUBGROUP_SIZE*1 + slid]; + } + if (grp == 0) { + total_sum += reduce_lm[SUBGROUP_SIZE*2 + slid]; + } + + if (grp == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(total_sum, 0, &(dst[gid * 2])); + } +} From 116a9f6ab79b518babc4c036d1bb56ac60a3da58 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Mon, 23 Mar 2026 15:33:49 -0700 Subject: [PATCH 043/249] hexagon: general DMA and Binary Op fixes for large strides (llama/20918) * hex-dma: make chained dma the default to handle newer models This also includes some new instrumentation that we can remove later. * hexagon: add uint32 dump helper * hexagon: use single-page VTCM allocation to avoid issues with large gather ops in ssm-conv ssm-conv uses HVX gather instruction and that instruction cannot handle cases where the base+offset spans page boundaries. * hexagon: update ssm-conv to make base-addr compute a bit easier to read * hex-dma: use 1d mode for reshaping, it supports sizes up to 24-bits (>16MB) * hex-bin: fix incorrect stride logic * hexagon: make sure repack buffs are dumped for verbose > 2 * hex-bin: consistently use dma_queue_push even for dummy dst transactions * hex-dma: start using 2d-wide mode on v75 and up The removes the need to deal with the 16-bit limitaion for the strides. * hex-bin: cleanup kernel selection logic * hex-bin: cleanup binary op core and fix transposed tensor handling * snapdragon: update run-bench to use larger ubatch and fa-on --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 12 +- ggml/src/ggml-hexagon/htp/binary-ops.c | 307 ++++++++++----------- ggml/src/ggml-hexagon/htp/hex-dma.c | 4 +- ggml/src/ggml-hexagon/htp/hex-dma.h | 307 ++++++++++++--------- ggml/src/ggml-hexagon/htp/hex-dump.h | 9 + ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 28 +- ggml/src/ggml-hexagon/htp/hvx-utils.h | 8 - ggml/src/ggml-hexagon/htp/main.c | 4 +- ggml/src/ggml-hexagon/htp/ssm-conv.c | 18 +- 9 files changed, 368 insertions(+), 329 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 8bcf5291c11..9c1ce93cc69 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -461,7 +461,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { d[7] = x[i * 8 + 7].d; } - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q4x4x2(y, i, k); } @@ -480,7 +480,7 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { const uint8_t * y_q = y + 0; // quants first const uint8_t * y_d = y + qrow_size; // then scales - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q4x4x2(y, i, k); } @@ -796,7 +796,7 @@ static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) { d[7] = x[i * 8 + 7].d; } - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q8x4x2(y, i, k); } @@ -814,7 +814,7 @@ static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) { const uint8_t * y_q = y + 0; // quants first const uint8_t * y_d = y + qrow_size; // then scales - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q8x4x2(y, i, k); } @@ -1149,7 +1149,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) e[7] = x[i * 8 + 7].e; } - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_mxfp4x4x2(y, i, k); } @@ -1168,7 +1168,7 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) const uint8_t * y_q = y + 0; // quants first const uint8_t * y_e = y + qrow_size; // then scales - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_mxfp4x4x2(y, i, k); } diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index ec90f22de52..1b0f97493bc 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -24,28 +24,26 @@ // Context for binary operations struct htp_binary_context { struct htp_ops_context * octx; - struct fastdiv_values dim1_div; - struct fastdiv_values dim2_div; - struct fastdiv_values dim12_div; + + struct fastdiv_values src0_dim1_div; // ne01 + struct fastdiv_values src0_dim2_div; // ne02 + struct fastdiv_values src0_dim12_div;// ne03 struct fastdiv_values src1_dim1_div; // ne11 struct fastdiv_values src1_dim2_div; // ne12 struct fastdiv_values src1_dim3_div; // ne13 - uint32_t nrows_per_thread; - bool split_at_ne01; - bool split_at_ne02; - - // Precomputed values uint32_t block_max; + uint32_t nrows_per_thread; size_t src0_row_size_aligned; size_t src1_row_size_aligned; size_t dst_row_size_aligned; - uint32_t src1_fetch_rows; // 1 or block_max - uint32_t src1_dma_stride; // 0 or stride + + bool split_at_ne01; + bool split_at_ne02; }; -#define htp_binary_preamble \ +#define htp_binary_preamble \ const struct htp_tensor * src0 = &octx->src0; \ const struct htp_tensor * src1 = &octx->src1; \ struct htp_tensor * dst = &octx->dst; \ @@ -72,12 +70,11 @@ struct htp_binary_context { const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, - uint32_t ne01, uint32_t ne02) { +static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, uint32_t ne01, uint32_t ne02) { uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); + i03 = fastdiv(ir, &bctx->src0_dim12_div); rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; uint32_t rows_left = end_row - ir; @@ -191,6 +188,8 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; + FARF(HIGH, "binary-scalar: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; @@ -204,9 +203,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; @@ -215,7 +214,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -229,9 +228,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); + i03 = fastdiv(ir, &bctx->src0_dim12_div); rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; // src1 indices (broadcast/repeat) @@ -255,9 +254,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); + p02 = fastdiv(prem, &bctx->src0_dim1_div); p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; @@ -282,6 +281,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; + FARF(HIGH, "binary-same-shape: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); @@ -297,9 +298,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; uint32_t i13 = (ne13 == 1) ? 0 : i03; @@ -307,23 +308,23 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint32_t i11 = (ne11 == 1) ? 0 : i01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; - uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * src1_curr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); - dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, src1_curr), bctx->src1_row_size_aligned, nb11, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } for (uint32_t ir = start_row; ir < end_row; ) { uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); - uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst; @@ -335,9 +336,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi } uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); + i03 = fastdiv(ir, &bctx->src0_dim12_div); rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); @@ -345,9 +346,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); + p02 = fastdiv(prem, &bctx->src0_dim1_div); p01 = prem - p02 * ne01; uint32_t p13 = (ne13 == 1) ? 0 : p03; @@ -358,7 +359,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); - dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, nb11, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } @@ -373,15 +374,17 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src0.type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; - const uint32_t start_row = bctx->nrows_per_thread * ith; - const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; + FARF(HIGH, "binary-row-bcast: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); - uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; @@ -391,15 +394,14 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, uint32_t ir_prefetch = start_row; int spad_idx = 0; - void * s1_ptr = (void *) src1_spad; + void * s1_ptr = (void *) src1_spad_base; for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); - rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; @@ -407,7 +409,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -415,7 +417,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, for (uint32_t ir = start_row; ir < end_row; ) { uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); - uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; for (uint32_t r = 0; r < current_block_size; r++) { @@ -425,21 +427,19 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00); } - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); - rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); - prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); - p01 = prem - p02 * ne01; + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; @@ -458,14 +458,16 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * const uint32_t src0_type = octx->src0.type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; - const uint32_t start_row = bctx->nrows_per_thread * ith; - const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; + FARF(HIGH, "binary-complex: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; - size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; dma_queue * q = octx->ctx->dma[ith]; uint32_t ir_prefetch = start_row; @@ -473,11 +475,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); - rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; @@ -485,7 +486,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -496,11 +497,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); - rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; for (uint32_t r = 0; r < current_block_size; r++) { uint32_t r_i01 = i01 + r; @@ -521,11 +521,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); - prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); - p01 = prem - p02 * ne01; + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; @@ -545,14 +544,16 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); const uint32_t row_size_bytes = ne00 * elem_size_bytes;; const uint32_t total_rows = ne01 * ne02 * ne03; - const uint32_t start_row = bctx->nrows_per_thread * ith; - const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; - size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + FARF(HIGH, "binary-repeat: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); dma_queue * q = octx->ctx->dma[ith]; uint32_t ir_prefetch = start_row; @@ -560,11 +561,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); - rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; @@ -572,7 +572,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -583,11 +583,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); - rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; for (uint32_t r = 0; r < current_block_size; r++) { uint32_t r_i01 = i01 + r; @@ -612,11 +611,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); - prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); - p01 = prem - p02 * ne01; + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; @@ -646,6 +644,7 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { const uint32_t nb02 = src0->nb[2]; const uint32_t nb03 = src0->nb[3]; const uint32_t nb11 = src1->nb[1]; // src1 row stride + const uint32_t nb1 = dst->nb[1]; const uint32_t nb2 = dst->nb[2]; const uint32_t nb3 = dst->nb[3]; @@ -657,8 +656,8 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; - size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; dma_queue * q = octx->ctx->dma[ith]; uint32_t ir_prefetch = start_row; @@ -666,11 +665,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); - rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; @@ -678,7 +676,7 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -689,11 +687,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); - rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; for (uint32_t r = 0; r < current_block_size; r++) { uint32_t r_i01 = i01 + r; // linear within block since we split at ne01 @@ -712,11 +709,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); - prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); - p01 = prem - p02 * ne01; + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); ir_prefetch += next_block_size; @@ -739,40 +735,36 @@ static int execute_op_binary(struct htp_ops_context * octx) { const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); const size_t src0_row_size = src0->ne[0] * elem_size; const size_t src1_row_size = src1->ne[0] * elem_size; - const size_t dst_row_size = dst->ne[0] * elem_size; + const size_t dst_row_size = dst->ne[0] * elem_size; - // Align to VLEN - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); bool is_add_id = (octx->op == HTP_OP_ADD_ID); bool is_scalar = !is_add_id && (src1->ne[0] == 1); - // Determine which kernel we will use to alloc memory and dispatch - bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) && + bool is_transposed = (src0->nb[1] < src0_row_size || src1->nb[1] < src1_row_size || dst->nb[1] < dst_row_size); + + bool is_same_shape = !is_add_id && !is_scalar && !is_transposed && + (src1->ne[0] == src0->ne[0] && src0->ne[0] % VLEN == 0) && (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) && (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) && (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1); - bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); - bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]); - bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]); + bool is_row_bcast = is_same_shape && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); + bool is_complex = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] == src0->ne[0]); + bool is_repeat = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] != src0->ne[0]); size_t spad_row_total; - if (is_scalar) { - spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); - } else if (is_row_bcast) { - spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); - } else if (use_vector_same) { + if (is_same_shape) { spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned); - } else if (is_add_id) { - spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly } else { spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); } size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total); + // Adjust for static src1 in row_bcast case if (is_row_bcast) { size_t needed_static = src1_row_size_aligned; @@ -782,28 +774,26 @@ static int execute_op_binary(struct htp_ops_context * octx) { } if (rows_per_buffer < 1) { - FARF(ERROR, "binary: VTCM too small\n"); - return HTP_STATUS_VTCM_TOO_SMALL; + FARF(ERROR, "binary: VTCM too small\n"); + return HTP_STATUS_VTCM_TOO_SMALL; } octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned; octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned; - if (is_scalar || use_complex || use_repeat || is_add_id) { - octx->src1_spad.size_per_thread = 0; - } else if (is_row_bcast) { + if (is_add_id || is_scalar || is_complex || is_repeat || is_row_bcast) { octx->src1_spad.size_per_thread = 0; } else { octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned; } + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; if (is_row_bcast) { octx->src1_spad.size = src1_row_size_aligned; } else { octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; } - octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) { return HTP_STATUS_VTCM_TOO_SMALL; @@ -823,46 +813,37 @@ static int execute_op_binary(struct htp_ops_context * octx) { } struct htp_binary_context bctx; - bctx.octx = octx; - bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; - bctx.block_max = rows_per_buffer; + bctx.octx = octx; + bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + bctx.block_max = rows_per_buffer; bctx.src0_row_size_aligned = src0_row_size_aligned; bctx.src1_row_size_aligned = src1_row_size_aligned; bctx.dst_row_size_aligned = dst_row_size_aligned; - bctx.dim1_div = init_fastdiv_values(src0->ne[1]); - bctx.dim2_div = init_fastdiv_values(src0->ne[2]); - bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]); + bctx.src0_dim1_div = init_fastdiv_values(src0->ne[1]); + bctx.src0_dim2_div = init_fastdiv_values(src0->ne[2]); + bctx.src0_dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]); - bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]); - bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]); - bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]); + bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]); + bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]); + bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]); bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]); - bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]); + bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]); bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]); - bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]); - - bctx.split_at_ne01 = (src0->ne[2] > 1) && - ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1); + bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]); - bctx.split_at_ne02 = (src0->ne[3] > 1) && - ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2); - - // Precompute specific kernel parameters - if (use_vector_same) { - bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1]; - bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer; - } + bctx.split_at_ne01 = (src0->ne[2] > 1) && ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1); + bctx.split_at_ne02 = (src0->ne[3] > 1) && ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2); worker_callback_t worker_func; - if (is_add_id) worker_func = binary_job_add_id; - else if (is_scalar) worker_func = binary_job_scalar; - else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast; - else if (use_vector_same) worker_func = binary_job_vector_same_shape; - else if (use_complex) worker_func = binary_job_vector_complex; - else worker_func = binary_job_element_repeat; + if (is_add_id) worker_func = binary_job_add_id; + else if (is_scalar) worker_func = binary_job_scalar; + else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast; + else if (is_same_shape) worker_func = binary_job_vector_same_shape; + else if (is_complex) worker_func = binary_job_vector_complex; + else worker_func = binary_job_element_repeat; if (is_row_bcast) { dma_queue_pop(q); diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.c b/ggml/src/ggml-hexagon/htp/hex-dma.c index 44e1be40c5d..b66e2d2603c 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.c +++ b/ggml/src/ggml-hexagon/htp/hex-dma.c @@ -31,8 +31,8 @@ dma_queue * dma_queue_create(size_t capacity) { q->capacity = capacity; q->idx_mask = capacity - 1; - q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t)); - memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t)); + q->desc = (dma_descriptor_2d *) memalign(64, capacity * sizeof(dma_descriptor_2d)); + memset(q->desc, 0, capacity * sizeof(dma_descriptor_2d)); q->dptr = (dma_ptr *) memalign(4, capacity * sizeof(dma_ptr)); memset(q->dptr, 0, capacity * sizeof(dma_ptr)); diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index 9811a07599f..ff166cbcc7a 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -10,19 +10,84 @@ extern "C" { #endif +// Define the HW descriptor structs here since the ones in HexSDK are a bit out of date +typedef struct dma_descriptor_1d_s { + void * next; + uint32_t size:24; + uint32_t desc_size:2; + uint32_t dst_comp:1; + uint32_t src_comp:1; + uint32_t dst_bypass:1; + uint32_t src_bypass:1; + uint32_t order:1; + uint32_t done:1; + void * src; + void * dst; +} dma_descriptor_1d; + +#if __HVX_ARCH__ < 75 + +typedef struct dma_descriptor_2d_s { + void * next; + uint32_t reserved0:24; + uint32_t desc_size:2; + uint32_t dst_comp:1; + uint32_t src_comp:1; + uint32_t dst_bypass:1; + uint32_t src_bypass:1; + uint32_t order:1; + uint32_t done:1; + void * src; + void * dst; + uint32_t desc_type:8; + uint32_t reserved1:24; + uint32_t row_size:16; + uint32_t nrows:16; + uint32_t src_stride:16; + uint32_t dst_stride:16; + uint32_t src_offset:16; + uint32_t dst_offset:16; +} dma_descriptor_2d; + +#else + +typedef struct dma_descriptor_2d_s { + void * next; + uint32_t dst_stride:24; + uint32_t desc_size:2; + uint32_t dst_comp:1; + uint32_t src_comp:1; + uint32_t dst_bypass:1; + uint32_t src_bypass:1; + uint32_t order:1; + uint32_t done:1; + void * src; + void * dst; + uint32_t desc_type:8; + uint32_t reserved0:24; + uint32_t row_size:24; + uint32_t nrows_lo:8; + uint32_t nrows_hi:8; + uint32_t src_stride:24; + uint32_t offset:24; + uint32_t reserved1:8; +} dma_descriptor_2d; + +#endif + typedef struct { - void *dst; + void *dst; const void *src; } dma_ptr; typedef struct { - hexagon_udma_descriptor_type1_t * desc; // descriptor pointers - hexagon_udma_descriptor_type1_t * tail; // tail pointer - dma_ptr * dptr; // dst/src pointers - uint32_t push_idx; - uint32_t pop_idx; - uint32_t capacity; - uint32_t idx_mask; + dma_descriptor_2d * desc; // descriptor pointers + dma_descriptor_2d * tail; // tail pointer + dma_ptr * dptr; // dst/src pointers + uint32_t push_idx; + uint32_t pop_idx; + uint32_t capacity; + uint32_t idx_mask; } dma_queue; dma_queue * dma_queue_create(size_t capacity); @@ -59,71 +124,87 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src) return p; } -static inline bool dma_queue_push(dma_queue * q, - dma_ptr dptr, - size_t dst_row_size, - size_t src_row_size, - size_t width, // width in bytes. number of bytes to transfer per row - size_t nrows) { +#if __HVX_ARCH__ < 73 +static const uint32_t dma_src_l2_bypass_on = 1; +static const uint32_t dma_dst_l2_bypass_on = 0; +#else +static const uint32_t dma_src_l2_bypass_on = 1; +static const uint32_t dma_dst_l2_bypass_on = 1; +#endif + +static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) { if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { - FARF(ERROR, "dma-push: queue full\n"); + FARF(HIGH, "dma-push: queue full\n"); return false; } - hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx]; + dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx]; + desc->next = NULL; + desc->desc_size = 0; // 1D mode + desc->src_bypass = dma_src_l2_bypass_on; + desc->dst_bypass = dma_dst_l2_bypass_on; + desc->order = 1; + desc->done = 0; + desc->src = (void *) dptr.src; + desc->dst = (void *) dptr.dst; + desc->size = size; + + q->dptr[q->push_idx] = dptr; + + dmlink(q->tail, desc); + q->tail = (dma_descriptor_2d *) desc; + + // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); + q->push_idx = (q->push_idx + 1) & q->idx_mask; + return true; +} + +static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) { + if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { + FARF(HIGH, "dma-push: queue full\n"); + return false; + } + + dma_descriptor_2d * desc = &q->desc[q->push_idx]; desc->next = NULL; - desc->length = 0; - desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1; - desc->dstbypass = 1; - desc->srcbypass = 1; -#if __HVX_ARCH__ >= 73 - desc->dstbypass = 1; - desc->srcbypass = 1; -#else - desc->dstbypass = 0; - desc->srcbypass = 1; -#endif - desc->order = 0; - desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE; + desc->reserved0 = 0; + desc->reserved1 = 0; + desc->desc_size = 1; // 2d mode + desc->src_bypass = dma_src_l2_bypass_on; + desc->dst_bypass = dma_dst_l2_bypass_on; + desc->src_comp = 0; + desc->dst_comp = 0; + desc->order = 1; + desc->done = 0; + desc->src_stride = src_stride; + desc->dst_stride = dst_stride; desc->src = (void *) dptr.src; desc->dst = (void *) dptr.dst; - desc->allocation = 0; - desc->padding = 0; - desc->roiwidth = width; - desc->roiheight = nrows; - desc->srcstride = src_row_size; - desc->dststride = dst_row_size; - desc->srcwidthoffset = 0; - desc->dstwidthoffset = 0; + desc->row_size = row_size; + +#if __HVX_ARCH__ < 75 + desc->desc_type = 0; // 2d (16-bit) mode + desc->nrows = nrows; + desc->src_offset = 0; + desc->dst_offset = 0; +#else + desc->desc_type = 9; // 2d (24-bit) mode + desc->nrows_lo = (nrows & 0xff); + desc->nrows_hi = (nrows >> 8); + desc->offset = 0; +#endif q->dptr[q->push_idx] = dptr; dmlink(q->tail, desc); q->tail = desc; - // FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src); + // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); q->push_idx = (q->push_idx + 1) & q->idx_mask; return true; } -static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q, - dma_ptr dptr, - size_t dst_row_size, - size_t src_row_size, - size_t nrows) { - return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows); -} - - -static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, - dma_ptr dptr, - size_t dst_row_size, - size_t src_row_size, - size_t nrows) { - return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows); -} - static inline dma_ptr dma_queue_pop(dma_queue * q) { dma_ptr dptr = { NULL }; @@ -131,12 +212,12 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) { return dptr; } - hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx]; + dma_descriptor_2d * desc = &q->desc[q->pop_idx]; // Wait for desc to complete while (1) { dmpoll(); - if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) { + if (desc->done) { break; } // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx); @@ -175,86 +256,62 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) { return q->capacity; } -// --------------------------------------------------------------------------- -// Overflow-safe DMA push: all UDMA type1 descriptor fields (roiwidth, -// roiheight, srcstride, dststride) are 16-bit, max 65535. This helper -// transparently handles values that exceed the 16-bit limit and submits -// chained DMA transtions. -// -// Case 1 (fast path): all params fit in 16 bits -> direct dma_queue_push. -// Case 2 (contiguous block): width == srcstride == dststride. Reshape the -// flat transfer into a 2D descriptor with sub_width <= 65535. Produces a -// single descriptor, preserving async DMA behavior. -// Case 3 (stride overflow): srcstride or dststride > 65535. Issue rows -// one at a time. The first N-1 rows are pushed+popped synchronously; -// the last row is left async so the caller can pop it. -// --------------------------------------------------------------------------- -#define UDMA_MAX_FIELD_VAL 65535u - -static inline bool dma_queue_push_chained(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t width, size_t nrows) { - // Fast path: everything fits in 16 bits. - if (__builtin_expect( - width <= UDMA_MAX_FIELD_VAL && - nrows <= UDMA_MAX_FIELD_VAL && - src_stride <= UDMA_MAX_FIELD_VAL && - dst_stride <= UDMA_MAX_FIELD_VAL, 1)) { - return dma_queue_push(q, dptr, dst_stride, src_stride, width, nrows); - } +#if __HVX_ARCH__ < 75 - // Case 2: contiguous block (width == src_stride == dst_stride). - // Reshape total bytes into sub_width * sub_nrows where sub_width <= 65535. - if (width == src_stride && width == dst_stride) { - size_t total = width * nrows; +// Overflow-safe DMA push: all 2d descriptor fields (row_size, nrows, src_stride, dst_stride) are 16-bit, max 65535. +// This version transparently handles values that exceed the 16-bit limit and submits chained DMA transtions. - // Pick the largest 128-byte-aligned sub_width that divides total evenly. - size_t sub_width = UDMA_MAX_FIELD_VAL & ~(size_t)127; // 65408 - while (sub_width > 0 && total % sub_width != 0) { - sub_width -= 128; - } - if (sub_width == 0) { - // Fallback: use original width (must fit) with adjusted nrows. - // This shouldn't happen for 128-aligned DMA sizes. - sub_width = width; - } - size_t sub_nrows = total / sub_width; - - // Handle sub_nrows > 65535 by issuing chunked descriptors. - const uint8_t *src = (const uint8_t *)dptr.src; - uint8_t *dst = (uint8_t *)dptr.dst; - size_t rows_done = 0; - while (rows_done < sub_nrows) { - size_t chunk = sub_nrows - rows_done; - if (chunk > UDMA_MAX_FIELD_VAL) chunk = UDMA_MAX_FIELD_VAL; - - dma_ptr p = dma_make_ptr(dst + rows_done * sub_width, src + rows_done * sub_width); - if (!dma_queue_push(q, p, sub_width, sub_width, sub_width, chunk)) - return false; +#define DMA_MAX_FIELD_VAL 65535u - rows_done += chunk; - // Complete all chunks without waiting except the last one, so the - // caller's single dma_queue_pop drains the final descriptor. - if (rows_done < sub_nrows) - dma_queue_pop_nowait(q); - } - return true; +static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) { + // Fast path: everything fits in 16 bits + if (nrows == 0 || __builtin_expect( + row_size <= DMA_MAX_FIELD_VAL && + nrows <= DMA_MAX_FIELD_VAL && + src_stride <= DMA_MAX_FIELD_VAL && + dst_stride <= DMA_MAX_FIELD_VAL, 1)) { + return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows); } - // Case 3: stride overflow — fall back to row-by-row. + // Contiguous block + // Use 1d DMA mode which supports sizes up to 24-bits (16MB) + if (nrows == 1 || (row_size == src_stride && row_size == dst_stride)) { + size_t total = row_size * nrows; + return dma_queue_push_single_1d(q, dptr, total); + } + + // Stride overflow — fall back to row-by-row. { - const uint8_t *src = (const uint8_t *)dptr.src; - uint8_t *dst = (uint8_t *)dptr.dst; + const uint8_t *src = (const uint8_t *) dptr.src; + uint8_t *dst = (uint8_t *) dptr.dst; for (size_t r = 0; r < nrows; ++r) { - dma_ptr p = dma_make_ptr(dst + r * dst_stride, - src + r * src_stride); - if (!dma_queue_push(q, p, 0, 0, width, 1)) - return false; - if (r + 1 < nrows) - dma_queue_pop_nowait(q); + dma_ptr p = dma_make_ptr(dst + r * dst_stride, src + r * src_stride); + if (!dma_queue_push_single_1d(q, p, row_size)) + return false; + if (r + 1 < nrows) + dma_queue_pop(q); } return true; } } +#else // HVX_ARCH >= 75 + +static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) { + // On v75 and up we always use 2d 24-bit mode + return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows); +} + +#endif + +static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) { + return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows); +} + +static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) { + return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows); +} + #ifdef __cplusplus } // extern "C" #endif diff --git a/ggml/src/ggml-hexagon/htp/hex-dump.h b/ggml/src/ggml-hexagon/htp/hex-dump.h index e3badb57f92..19d173c2232 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dump.h +++ b/ggml/src/ggml-hexagon/htp/hex-dump.h @@ -21,6 +21,15 @@ static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t FARF(HIGH, "%s\n", str); } +static inline void hex_dump_uint32_line(char * pref, const uint32_t * x, uint32_t n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += snprintf(p, p_end - p, "%u, ", (unsigned int) x[i]); + } + FARF(HIGH, "%s\n", str); +} + static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) { char str[1024], *p = str, *p_end = str + sizeof(str); p += snprintf(p, p_end - p, "%s: ", pref); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index c703a049426..a56356bee9f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -727,7 +727,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if (use_dma_activation) { const size_t row_bytes = (size_t) params->k * sizeof(float); const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); - dma_queue_push_chained(ctx->dma[0], + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_f32_act, activation_chunk), row_bytes, stride_bytes, row_bytes, n_rows); dma_queue_pop(ctx->dma[0]); @@ -747,7 +747,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu { const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); } @@ -765,7 +765,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); } @@ -891,7 +891,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co if (use_dma_activation) { const size_t row_bytes = (size_t) k * sizeof(float); const size_t stride_bytes = (size_t) act_stride * sizeof(float); - dma_queue_push_chained(ctx->dma[0], + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_f32_act, activation_chunk), row_bytes, stride_bytes, row_bytes, n_rows); dma_queue_pop(ctx->dma[0]); @@ -916,7 +916,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co { const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); } @@ -933,7 +933,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); } @@ -1104,7 +1104,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // because UDMA roiwidth is 16-bit and total size can exceed 65535. { const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); } for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { @@ -1120,7 +1120,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride; - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next); } // Dequant + vscatter writes directly to [K, N] transposed tiles. @@ -1173,7 +1173,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds { // Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow. const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); } { @@ -1191,7 +1191,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); if (1 < n_chunk_cnt) { const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); } // C0 @@ -1218,7 +1218,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // issue A_{i+2} if (i + 2 < n_chunk_cnt) { const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); } // wait for HMX (C_{i}) -- C_{i} is done @@ -1443,7 +1443,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict { const float *activation_block = x + mr * k + kk; - dma_queue_push_chained(ctx->dma[0], + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_scratch1, activation_block), k_blk_sz * sizeof(float), k * sizeof(float), @@ -1472,10 +1472,10 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE; // 2D DMA: quants sub-range - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off), + dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off), s.dst_stride, s.src_stride, s.quant_width, s.n_rows); // 2D DMA: scales sub-range - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off), + dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off), s.dst_stride, s.src_stride, s.scale_width, s.n_rows); } TIMER_STOP(fetch); diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 08343798794..a518ad37331 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -15,12 +15,4 @@ #include "hvx-div.h" #include "hvx-base.h" -#ifndef GATHER_TYPE -# if defined(__hexagon__) -# define GATHER_TYPE(_a) (intptr_t) _a -# else -# define GATHER_TYPE(_a) (HVX_Vector *) _a -# endif -#endif - #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index ef9cba8ecc1..70ba9f9f4fe 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -214,7 +214,7 @@ static int vtcm_alloc(struct htp_context * ctx) { HAP_compute_res_attr_init(&attr); HAP_compute_res_attr_set_serialize(&attr, 0); HAP_compute_res_attr_set_cache_mode(&attr, 1); - HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size); + HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size); // single page HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx); HAP_compute_res_attr_set_hmx_param(&attr, 1); @@ -319,7 +319,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->n_threads = n_hvx; for (int i = 0; i < ctx->n_threads; i++) { // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541 - ctx->dma[i] = dma_queue_create(64); + ctx->dma[i] = dma_queue_create(128); } // init worker pool diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index b3c1ef9572e..6b035810d57 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -151,7 +151,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void const int dr = scctx->nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = MIN(ir0 + dr, d_inner); - const int ir = ir1 - ir0; + const uint32_t ir = ir1 - ir0; if (ir0 >= ir1) { return; // No work for this thread @@ -205,10 +205,10 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void HVX_Vector acc_vec = Q6_V_vsplat_R(0); for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), - src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), - src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); + uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); + Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); @@ -222,10 +222,10 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void HVX_Vector acc_vec = Q6_V_vsplat_R(0); for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), - src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), - src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); + uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); + Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); From eef7422d4d6bd336a9343b0a04b20f94ad9c80a2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 24 Mar 2026 10:03:09 +0200 Subject: [PATCH 044/249] metal : add FA instantiations for HSK=512, HSV=512 (llama/20902) --- ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal.metal | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 14144aab087..2fbb274c5f9 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1148,6 +1148,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[0]->ne[0] != 192 && op->src[0]->ne[0] != 256 && op->src[0]->ne[0] != 320 && + op->src[0]->ne[0] != 512 && op->src[0]->ne[0] != 576) { return false; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9c6b1c4f62b..9286675189d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -6269,6 +6269,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6284,6 +6285,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_HAS_BF16) @@ -6300,6 +6302,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif @@ -6316,6 +6319,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6331,6 +6335,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6346,6 +6351,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6361,6 +6367,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6376,6 +6383,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES @@ -6957,6 +6965,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) From 9e4e4c2401e6fa73bbaedab1e86512c34aee052c Mon Sep 17 00:00:00 2001 From: nuri Date: Tue, 24 Mar 2026 17:13:07 +0900 Subject: [PATCH 045/249] metal : add FLOOR, CEIL, ROUND, TRUNC unary ops (llama/20930) Co-authored-by: nryoo --- ggml/src/ggml-metal/ggml-metal-device.cpp | 4 ++++ ggml/src/ggml-metal/ggml-metal-device.m | 4 ++++ ggml/src/ggml-metal/ggml-metal-impl.h | 4 ++++ ggml/src/ggml-metal/ggml-metal.metal | 16 ++++++++++++++++ 4 files changed, 28 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 9162342ee98..89539bd7615 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -246,6 +246,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break; case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break; case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break; + case GGML_UNARY_OP_FLOOR: op_num = OP_UNARY_NUM_FLOOR; break; + case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break; + case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break; + case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 2fbb274c5f9..cbef2fb4879 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1039,6 +1039,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_EXPM1: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index ea471090cd8..eb2253e029a 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -120,6 +120,10 @@ #define OP_UNARY_NUM_EXP 114 #define OP_UNARY_NUM_SOFTPLUS 115 #define OP_UNARY_NUM_EXPM1 116 +#define OP_UNARY_NUM_FLOOR 117 +#define OP_UNARY_NUM_CEIL 118 +#define OP_UNARY_NUM_ROUND 119 +#define OP_UNARY_NUM_TRUNC 120 #define OP_SUM_ROWS_NUM_SUM_ROWS 10 #define OP_SUM_ROWS_NUM_MEAN 11 diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9286675189d..2074211594c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1094,6 +1094,22 @@ kernel void kernel_unary_impl( // TODO: precise implementation dst_ptr[i0] = (T) (exp(x) - 1); } + + if (FC_OP == OP_UNARY_NUM_FLOOR) { + dst_ptr[i0] = (T) floor(x); + } + + if (FC_OP == OP_UNARY_NUM_CEIL) { + dst_ptr[i0] = (T) ceil(x); + } + + if (FC_OP == OP_UNARY_NUM_ROUND) { + dst_ptr[i0] = (T) round(x); + } + + if (FC_OP == OP_UNARY_NUM_TRUNC) { + dst_ptr[i0] = (T) trunc(x); + } } #undef FC_OP From f2a8e65ea7e4b58cf862a832c2c1fabd2e6ff63f Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Wed, 25 Mar 2026 17:48:37 +0800 Subject: [PATCH 046/249] sycl : fix wrong variable check by assert (llama/20903) * fix wrong variable check by assert * use GGML api --- ggml/src/ggml-sycl/add-id.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-sycl/add-id.cpp b/ggml/src/ggml-sycl/add-id.cpp index 8929017a999..e0adc4fe423 100644 --- a/ggml/src/ggml-sycl/add-id.cpp +++ b/ggml/src/ggml-sycl/add-id.cpp @@ -56,7 +56,7 @@ void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { float* dst_d = (float*)dst->data; const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; - assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + GGML_ASSERT(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); int threads = std::min((unsigned int)ne00, max_work_group_size); // cols From 3987857d2db803eddfc82d61d199913f2013dfab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 25 Mar 2026 11:53:16 +0100 Subject: [PATCH 047/249] llama: fix llama-model-saver (llama/20503) * llama : add fd-based model loading via llama_model_load_from_fd * llama : address review feedback for fd-based model loading * llama : use FILE pointer instead of fd in public API * llama : use FILE pointer consistently, address review feedback * fixup * fix tensor names * fix llama-model-saver * roundtrip tests * fixup * refactor tests * fix prints * fix model saving * fix CI, disable Chameleon * print seed --------- Co-authored-by: Siddhesh2377 --- ggml/include/gguf.h | 2 ++ ggml/src/ggml-impl.h | 1 - ggml/src/gguf.cpp | 33 +++++++++++++++++++++++---------- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h index 79ee202062b..02d5f221c03 100644 --- a/ggml/include/gguf.h +++ b/ggml/include/gguf.h @@ -77,6 +77,7 @@ extern "C" { }; GGML_API struct gguf_context * gguf_init_empty(void); + GGML_API struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params); GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); //GGML_API struct gguf_context * gguf_init_from_buffer(..); @@ -189,6 +190,7 @@ extern "C" { // // write the entire context to a binary file + GGML_API bool gguf_write_to_file_ptr(const struct gguf_context * ctx, FILE * file, bool only_meta); GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 92568655956..0639db362e7 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -773,6 +773,5 @@ inline bool ggml_check_edges(const struct ggml_cgraph * cgraph, // expose GGUF internals for test code GGML_API size_t gguf_type_size(enum gguf_type type); -GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta); #endif // __cplusplus diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index cbeedf6c4b6..ab3cc974867 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -394,7 +394,11 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector & bu gguf_write_out(ctx, gw, only_meta); } +bool gguf_write_to_file_ptr(const struct gguf_context * ctx, FILE * file, bool only_meta) { + GGML_ASSERT(file); + + try { + gguf_writer_file gw(file); + gguf_write_out(ctx, gw, only_meta); + } catch (const std::runtime_error& ex) { + GGML_LOG_ERROR("%s: failed to write GGUF data: %s\n", __func__, ex.what()); + return false; + } + return true; +} + bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { FILE * file = ggml_fopen(fname, "wb"); @@ -1516,17 +1533,13 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo return false; } - try { - gguf_writer_file gw(file); - gguf_write_out(ctx, gw, only_meta); - } catch (const std::runtime_error& ex) { - GGML_LOG_ERROR("%s: failed to write GGUF data into '%s': %s\n", __func__, fname, ex.what()); - fclose(file); - return false; + const bool success = gguf_write_to_file_ptr(ctx, file, only_meta); + if (!success) { + GGML_LOG_ERROR("%s: failed to write GGUF data into '%s'\n", __func__, fname); } fclose(file); - return true; + return success; } size_t gguf_get_meta_size(const struct gguf_context * ctx) { From 495b77aec29017b13a2dfe5d29b35eb677056d08 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Wed, 25 Mar 2026 19:57:40 +0100 Subject: [PATCH 048/249] mtmd: Add DeepSeekOCR Support (llama/17400) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mtmd: llama.cpp DeepSeekOCR support init commit * loading sam tensors * mtmd: fix vision model processing * deepseek-ocr clip-vit model impl * mtmd: add DeepSeek-OCR LM support with standard attention * mtmd: successfully runs DeepSeek-OCR LM in llama-cli * mtmd: Fix RoPE type for DeepSeek-OCR LM. * loading LM testing Vision model loading * sam warmup working * sam erroneous return corrected * clip-vit: corrected cls_embd concat * clip-vit: model convert qkv_proj split * corrected combining of image encoders' results * fix: update callback for ffn_moe_weighted and add callback for attn_out in deepseek2 model * concat image_newline and image_seperator tokens * visual_model warmup (technically) works * window partitioning using standard ggml ops * sam implementation without using CPU only ops * clip: fixed warnings * Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into sf/deepseek-ocr * mtmd: fix get_rel_pos * mtmd: fixed the wrong scaler for get_rel_pos * image encoding technically works but the output can't be checked singe image decoding fails * mtmd: minor changed * mtmd: add native resolution support * - image encoding debugged - issues fixed mainly related wrong config like n_patches etc. - configs need to be corrected in the converter * mtmd: correct token order * - dynamic resizing - changes are concerning PR https://github.com/sfallah/llama.cpp/pull/4 * mtmd: quick fix token order * mtmd: fix danling pointer * mtmd: SAM numerically works * mtmd: debug CLIP-L (vit_pre_ln) * mtmd: debug CLIP-L & first working DeepSeek-OCR model * mtmd : add --dsocr-mode CLI argument for DeepSeek-OCR resolution control & all native resolution modes work * mtmd: simplify SAM patch embedding * mtmd: adapt Pillow image resizing function * mtmd: simplify DeepSeek-OCR dynamic resolution preprocessing * mtmd: remove --dsocr-mode argument * mtmd: refactor code & remove unused helper functions * mtmd: fix tensor names for image newlines and view separator * clean up * reverting automatically removed spaces * reverting automatically removed spaces * mtmd: fixed bad ocr check in Deepseek2 (LM) * mtmd: support combined QKV projection in buid_vit * using common build_attn in sam * corrected code-branch when flash-attn disabled enabling usage of --flash-attn option * mtmd: minor fix * minor formatting and style * fixed flake8 lint issues * minor editorconfig-check fixes * minor editorconfig-check fixes * mtmd: simplify get_rel_pos * mtmd: make sam hparams configurable * mtmd: add detailed comments for resize_bicubic_pillow * mtmd: fixed wrong input setting * mtmd: convert model in FP16 * mtmd: minor fix * mtmd: remove tweak to llama-mtmd-cli & deepseek-ocr template * fix: test-1.jpg ORC issue with small (640) resolution setting min-resolution base (1024) max large (1280) for dynamic-resolution * minor: editconfig-check fix * merge with changes from https://github.com/ggml-org/llama.cpp/pull/17909 added new opt to tests.sh to disable flash-attn * minor: editconfig-check fix * testing deepseek-ocr quick and dirty test script comparing results of Qwen2.5-VL vs DeepSeek-OCR * quick and (potential) dirty merge with https://github.com/ggml-org/llama.cpp/pull/17909 * refactoring, one single builder function and static helpers * added deepseek-ocr test to tests.sh * minor formatting fixes * check with fixed expected resutls * minor formatting * editorconfig-check fix * merge with changes from https://github.com/ggml-org/llama.cpp/pull/18042 * minor - added GLM-4.6V to big tests - added missing deps for python test * convert: minor fix * mtmd: format code * convert: quick fix * convert: quick fix * minor python formatting * fixed merge build issue * merge resolved - fixed issues in convert - tested several deepseek models * minor fix * minor * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret * - removed clip_is_deepseekocr - removed redundant RESIZE_ALGO_BICUBIC_PILLOW resize-algo - simplified image-preprocessing - removed/simplified debug functions * - cleaning commented out code * fixing instabilities issues reintroducing resize_bicubic_pillow * - use f16 model for deepseek-ocr test - ignore llama-arch test for deepseek-ocr * rename fc_w --> mm_fc_w * add links to OCR discussion * cleaner loading code * add missing .weight to some tensors * add default jinja template (to be used by server) * move test model to ggml-org * rolling back upscale change * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: bluebread Co-authored-by: Sigbjørn Skjæret Co-authored-by: Xuan Son Nguyen Co-authored-by: Xuan-Son Nguyen --- ggml/src/ggml.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4c0764a0ac5..e9b6720c0af 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4962,6 +4962,7 @@ static struct ggml_tensor * ggml_interpolate_impl( GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT); // TODO: implement antialias for modes other than bilinear GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR); + GGML_ASSERT(a->type == GGML_TYPE_F32); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); @@ -5307,6 +5308,7 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(q->ne[3] == v->ne[3]); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); From a050c7d1bf2aae985cba6896cdbb6644383f20bf Mon Sep 17 00:00:00 2001 From: Yihao Wang <42559837+AgainstEntropy@users.noreply.github.com> Date: Wed, 25 Mar 2026 19:19:14 -0700 Subject: [PATCH 049/249] CUDA & CPU: support F32 kernel type for `CONV_TRANSPOSE_2D` (llama/17094) * Refactor CUDA 2D transpose implementation to support multiple kernel types and improve parameter handling - Introduced a `conv2d_transpose_params` struct for better parameter management. - Updated `conv2d_transpose_kernel` to be templated for different kernel types (float and half). - Modified `ggml_cuda_conv_2d_transpose_p0` to handle both F16 and F32 kernel types. - Enhanced test cases to validate functionality for both kernel types. * Refactor test cases for 2D convolution transpose to support dynamic kernel types - Updated `test_conv_transpose_2d` structure to improve parameter handling by reordering constructor arguments. - Enhanced test case generation to iterate over kernel types, allowing for flexible testing of different configurations. - Removed hardcoded kernel type instances in favor of a loop for better maintainability and scalability. * Refactor ggml_compute_forward_conv_transpose_2d to support both F16 and F32 tensor types. * Refactor conv2d transpose kernel to use a template for kernel type, enhancing flexibility for different data types. Update test cases to include both F16 and F32 tensor types for comprehensive coverage. * Update ggml/src/ggml-cuda/conv2d-transpose.cu Co-authored-by: Aman Gupta * Update ggml/src/ggml-cpu/ggml-cpu.c Co-authored-by: Aman Gupta * Refactor conv2d transpose implementation by removing the conv2d_transpose_params struct and dispatching with direct kernel launch. * Enhance cpu conv2d transpose implementation by introducing a templated kernel type for improved flexibility with F16 and F32 data types. --------- Co-authored-by: Aman Gupta --- ggml/src/ggml-cpu/ggml-cpu.c | 8 ++- ggml/src/ggml-cpu/ops.cpp | 69 ++++++++++++++++++------- ggml/src/ggml-cuda/conv2d-transpose.cu | 66 +++++++++++++++-------- ggml/src/ggml-cuda/conv2d-transpose.cuh | 1 + 4 files changed, 102 insertions(+), 42 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8b323bd9b06..df17cc55300 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2871,8 +2871,12 @@ struct ggml_cplan ggml_graph_plan( const int64_t ne11 = node->src[1]->ne[1]; // H const int64_t ne12 = node->src[1]->ne[2]; // Channels In - cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; - cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; + GGML_ASSERT(node->src[0]->type == GGML_TYPE_F16 || node->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(node->src[1]->type == GGML_TYPE_F32); + + cur += ggml_type_size(node->src[0]->type) * ne00 * ne01 * ne02 * ne03; + cur += ggml_type_size(node->src[0]->type) * ne10 * ne11 * ne12; + } break; case GGML_OP_TOP_K: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3f85e531daa..d950972c83e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -6923,16 +6923,15 @@ void ggml_compute_forward_conv_3d( ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type); } -// ggml_compute_forward_conv_transpose_2d - -void ggml_compute_forward_conv_transpose_2d( - const ggml_compute_params * params, - ggml_tensor * dst) { +template +static void ggml_compute_forward_conv_transpose_2d_impl( + const ggml_compute_params * params, + ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -6943,7 +6942,7 @@ void ggml_compute_forward_conv_transpose_2d( const int nk = ne00*ne01*ne02*ne03; - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); GGML_ASSERT(nb10 == sizeof(float)); if (ith == 0) { @@ -6951,12 +6950,12 @@ void ggml_compute_forward_conv_transpose_2d( // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + kernel_t * const wdata = (kernel_t *) params->wdata + 0; for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02); - ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03; + const kernel_t * const src = (kernel_t *)((char *) src0->data + i03*nb03 + i02*nb02); + kernel_t * dst_data = wdata + i02*ne01*ne00*ne03; for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i00 = 0; i00 < ne00; i00++) { dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00]; @@ -6968,13 +6967,17 @@ void ggml_compute_forward_conv_transpose_2d( // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh) { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + kernel_t * const wdata = (kernel_t *) params->wdata + nk; for (int i12 = 0; i12 < ne12; i12++) { for (int i11 = 0; i11 < ne11; i11++) { const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11); - ggml_fp16_t * dst_data = wdata + i11*ne10*ne12; + kernel_t * dst_data = wdata + i11*ne10*ne12; for (int i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]); + if constexpr (std::is_same_v) { + dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]); + } else { + dst_data[i10*ne12 + i12] = src[i10]; + } } } } @@ -6996,21 +6999,27 @@ void ggml_compute_forward_conv_transpose_2d( const int ip0 = dp*ith; const int ip1 = MIN(ip0 + dp, np); - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - ggml_fp16_t * const wdata_src = wdata + nk; + kernel_t * const wdata = (kernel_t *) params->wdata + 0; + kernel_t * const wdata_src = wdata + nk; for (int i2 = ip0; i2 < ip1; i2++) { // Cout float * dst_data = (float *)((char *) dst->data + i2*nb2); - ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; + kernel_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; for (int i11 = 0; i11 < ne11; i11++) { for (int i10 = 0; i10 < ne10; i10++) { const int i1n = i11*ne10*ne12 + i10*ne12; for (int i01 = 0; i01 < ne01; i01++) { for (int i00 = 0; i00 < ne00; i00++) { float v = 0; - ggml_vec_dot_f16(ne03, &v, 0, - wdata_src + i1n, 0, - wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + if constexpr (std::is_same_v) { + ggml_vec_dot_f16(ne03, &v, 0, + wdata_src + i1n, 0, + wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + } else { + ggml_vec_dot_f32(ne03, &v, 0, + wdata_src + i1n, 0, + wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + } dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; } } @@ -7019,6 +7028,28 @@ void ggml_compute_forward_conv_transpose_2d( } } +void ggml_compute_forward_conv_transpose_2d( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_transpose_2d_impl(params, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_transpose_2d_impl(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_conv_2d_dw struct ggml_conv_2d_dw_params { diff --git a/ggml/src/ggml-cuda/conv2d-transpose.cu b/ggml/src/ggml-cuda/conv2d-transpose.cu index 03224e404d3..6cbd6f879e6 100644 --- a/ggml/src/ggml-cuda/conv2d-transpose.cu +++ b/ggml/src/ggml-cuda/conv2d-transpose.cu @@ -1,12 +1,20 @@ -#include - #include "conv2d-transpose.cuh" -#include "ggml.h" - -__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel, - float * __restrict__ output, const int in_w, const int in_h, const int out_w, - const int out_h, const int kernel_w, const int kernel_h, const int stride, - const int c_in, const int c_out, const int batches) { +#include "convert.cuh" + +template +static __global__ void conv2d_transpose_kernel(const float * __restrict__ input, + const kernel_t * __restrict__ kernel, + float * __restrict__ output, + const int in_w, + const int in_h, + const int out_w, + const int out_h, + const int kernel_w, + const int kernel_h, + const int stride, + const int c_in, + const int c_out, + const int batches) { const int global_idx = blockIdx.x * blockDim.x + threadIdx.x; const int total_elements = out_w * out_h * c_out * batches; @@ -26,24 +34,32 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) { for (int kh = 0; kh < kernel_h; ++kh) { int in_y = out_y_idx - kh; - if (in_y < 0 || in_y % stride) continue; + if (in_y < 0 || in_y % stride) { + continue; + } in_y /= stride; - if (in_y >= in_h) continue; + if (in_y >= in_h) { + continue; + } for (int kw = 0; kw < kernel_w; ++kw) { int in_x = out_x_idx - kw; - if (in_x < 0 || in_x % stride) continue; + if (in_x < 0 || in_x % stride) { + continue; + } in_x /= stride; - if (in_x >= in_w) continue; + if (in_x >= in_w) { + continue; + } const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x; const int kernel_idx = (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw; - float input_val = input[input_idx]; - half kern_val = kernel[kernel_idx]; + float input_val = input[input_idx]; + kernel_t kern_val = kernel[kernel_idx]; - accumulator += input_val * (float) kern_val; + accumulator += input_val * ggml_cuda_cast(kern_val); } } } @@ -56,11 +72,12 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor const ggml_tensor * kernel = dst->src[0]; const ggml_tensor * input = dst->src[1]; - GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); + GGML_ASSERT(input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); const float * input_data = (const float *) input->data; float * output_data = (float *) dst->data; - const half * kernel_data = (const half *) kernel->data; + const void * kernel_data = kernel->data; const int input_w = input->ne[0]; const int input_h = input->ne[1]; @@ -82,10 +99,17 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(ggml_is_contiguous(kernel)); GGML_ASSERT(ggml_is_contiguous(dst)); - const int total = (output_w * output_h * channels_out * batches); + const int total = output_w * output_h * channels_out * batches; const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE; - conv2d_transpose_kernel<<>>( - input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride, - channels_in, channels_out, batches); + if (kernel->type == GGML_TYPE_F16) { + conv2d_transpose_kernel<<>>( + input_data, (const half *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, + kernel_h, stride, channels_in, channels_out, batches); + + } else { + conv2d_transpose_kernel<<>>( + input_data, (const float *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, + kernel_h, stride, channels_in, channels_out, batches); + } } diff --git a/ggml/src/ggml-cuda/conv2d-transpose.cuh b/ggml/src/ggml-cuda/conv2d-transpose.cuh index c9430b24850..72889c5f0fa 100644 --- a/ggml/src/ggml-cuda/conv2d-transpose.cuh +++ b/ggml/src/ggml-cuda/conv2d-transpose.cuh @@ -1,4 +1,5 @@ #include "common.cuh" #define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256 + void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From eb747f3def7907b11b661a8af1450dd508fd6c9d Mon Sep 17 00:00:00 2001 From: Michael Wand Date: Thu, 26 Mar 2026 01:54:03 -0700 Subject: [PATCH 050/249] ggml-cuda: Add NVFP4 dp4a kernel (llama/20644) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added check for dst_t to cuda_cast template for float Restored ggml_cuda_ue4m3_to_fp32, changed vecdot ints to int32ts Added CUDART/HIP Check and HIP/fp8 include Added NVFP4 to Test-backend-ops Added hip_fp8_e4m3 to __nv_fp8_e4m3 typedef --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 17 ++++++++++++ ggml/src/ggml-cuda/convert.cu | 43 +++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/ggml-cuda.cu | 10 ++++++- ggml/src/ggml-cuda/mmvq.cu | 8 ++++++ ggml/src/ggml-cuda/vecdotq.cuh | 32 +++++++++++++++++++++++ ggml/src/ggml-cuda/vendors/cuda.h | 5 ++-- ggml/src/ggml-cuda/vendors/hip.h | 6 +++++ 7 files changed, 118 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 36d8a3aaab2..9f93c70d21d 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -799,6 +799,16 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #endif // CUDART_VERSION >= 12050 } +static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) { +#ifdef FP8_AVAILABLE + const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. + const __nv_fp8_e4m3 xf = *reinterpret_cast(&bits); + return static_cast(xf) / 2; +#else + NO_DEVICE_CODE; +#endif // FP8_AVAILABLE +} + __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) { const uint8_t sign_bit = (x < 0.0f) << 3; float ax = fabsf(x) * e; @@ -931,6 +941,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI_MXFP4; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_NVFP4; + static constexpr int qr = QR_NVFP4; + static constexpr int qi = QI_NVFP4; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index b70492c7d6c..79ccfe568a2 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -617,6 +617,45 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_mxfp4<<>>(vx, y); } +template +static __global__ void dequantize_block_nvfp4( + const void * __restrict__ vx, + dst_t * __restrict__ yy, + const int64_t ne) { + const int64_t i = blockIdx.x; + const int tid = threadIdx.x; + + const int64_t base = i * QK_NVFP4; + if (base >= ne) { + return; + } + + const block_nvfp4 * x = (const block_nvfp4 *) vx; + const block_nvfp4 & xb = x[i]; + + const int sub = tid / (QK_NVFP4_SUB / 2); + const int j = tid % (QK_NVFP4_SUB / 2); + + const float d = ggml_cuda_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j]; + + const int64_t y0 = base + sub * QK_NVFP4_SUB + j; + const int64_t y1 = y0 + QK_NVFP4_SUB / 2; + + yy[y0] = ggml_cuda_cast(d * kvalues_mxfp4[q & 0x0F]); + yy[y1] = ggml_cuda_cast(d * kvalues_mxfp4[q >> 4]); +} + +template +static void dequantize_row_nvfp4_cuda( + const void * vx, + dst_t * y, + const int64_t k, + cudaStream_t stream) { + GGML_ASSERT(k % QK_NVFP4 == 0); + const int nb = k / QK_NVFP4; + dequantize_block_nvfp4<<>>(vx, y, k); +} template static __global__ void convert_unary( const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, @@ -715,6 +754,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_cuda; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_cuda; case GGML_TYPE_F32: return convert_unary_cont_cuda; case GGML_TYPE_BF16: @@ -766,6 +807,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_cuda; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_cuda; case GGML_TYPE_F16: return convert_unary_cont_cuda; case GGML_TYPE_BF16: diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index a31e843e153..cc80eb3ffc2 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1297,7 +1297,12 @@ static void ggml_cuda_op_mul_mat_cublas( const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); - const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; + const bool use_fp16 = + src0->type != GGML_TYPE_NVFP4 && + (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && + ggml_is_contiguous(src0) && + row_diff == src0->ne[1] && + dst->op_params[0] == GGML_PREC_DEFAULT; if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { ggml_cuda_pool_alloc src1_as_bf16(ctx.pool(id)); @@ -4781,6 +4786,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: +#ifdef FP8_AVAILABLE + case GGML_TYPE_NVFP4: +#endif // FP8_AVAILABLE case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 024b3d8cf22..66bd8beeae7 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -15,6 +15,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1; case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1; case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1; + case GGML_TYPE_NVFP4: return vec_dot_nvfp4_q8_1; case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1; case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1; case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1; @@ -41,6 +42,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ; case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ; case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ; + case GGML_TYPE_NVFP4: return VDR_NVFP4_Q8_1_MMVQ; case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ; case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ; case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ; @@ -626,6 +628,12 @@ static void mul_mat_vec_q_switch_type( nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; + case GGML_TYPE_NVFP4: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; case GGML_TYPE_Q2_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index ab803aca21b..40b2b41e7e8 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -322,6 +322,38 @@ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1( return d * sumi; } +#define VDR_NVFP4_Q8_1_MMVQ 4 +#define VDR_NVFP4_Q8_1_MMQ 8 + +static __device__ __forceinline__ float vec_dot_nvfp4_q8_1( + const void * __restrict__ vbq, + const block_q8_1 * __restrict__ bq8_1, + const int32_t & kbx, + const int32_t & iqs) { + + const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq + kbx; + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) { + const int32_t iqs0 = iqs + 2*i; + const int32_t iqs1 = iqs0 + 1; + const int32_t is = iqs0 >> 1; + const int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4); + const int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4); + const block_q8_1 * bq8 = bq8_1 + (is >> 1); + const int32_t i8 = ((is & 1) << 2); + + int sumi = ggml_cuda_dp4a(v0.x, get_int_b4(bq8->qs, i8 + 0), 0); + sumi = ggml_cuda_dp4a(v0.y, get_int_b4(bq8->qs, i8 + 2), sumi); + sumi = ggml_cuda_dp4a(v1.x, get_int_b4(bq8->qs, i8 + 1), sumi); + sumi = ggml_cuda_dp4a(v1.y, get_int_b4(bq8->qs, i8 + 3), sumi); + + const float d = ggml_cuda_ue4m3_to_fp32(bq4->d[is]) * __low2float(bq8->ds); + sum += d * float(sumi); + } + + return sum; +} #define VDR_Q2_K_Q8_1_MMVQ 1 #define VDR_Q2_K_Q8_1_MMQ 4 diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index ba032cfab4b..07bc47df3b8 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -6,9 +6,10 @@ #include #include -#if CUDART_VERSION >= 12050 +#if CUDART_VERSION >= 11080 #include -#endif // CUDART_VERSION >= 12050 +#define FP8_AVAILABLE +#endif // CUDART_VERSION >= 11080 #if CUDART_VERSION >= 12080 #include diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 35d1e1a0639..9d9ba1ee219 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -235,6 +235,12 @@ typedef __hip_bfloat16 nv_bfloat16; typedef __hip_bfloat162 nv_bfloat162; +#if HIP_VERSION >= 60200000 +#include +typedef __hip_fp8_e4m3 __nv_fp8_e4m3; +#define FP8_AVAILABLE +#endif // HIP_VERSION >= 60200000 + typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); static __device__ __forceinline__ int __vsubss4(const int a, const int b) { From 07237ff99e479c364d778c1f4d1cad8729305dca Mon Sep 17 00:00:00 2001 From: ihb2032 <40718643+ihb2032@users.noreply.github.com> Date: Thu, 26 Mar 2026 19:08:41 +0800 Subject: [PATCH 051/249] fix(ggml): correct RISC-V ISA string canonical ordering for RVV in CMake (llama/20888) Signed-off-by: ihb2032 --- ggml/src/ggml-cpu/CMakeLists.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 1a1bbc9f2be..beebc4760d2 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -460,6 +460,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() if(NOT GGML_CPU_ALL_VARIANTS) set(MARCH_STR "rv64gc") + if (GGML_RVV) + string(APPEND MARCH_STR "v") + endif() + if (GGML_RV_ZFH) string(APPEND MARCH_STR "_zfh") endif() @@ -467,7 +471,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_XTHEADVECTOR) string(APPEND MARCH_STR "_xtheadvector") elseif (GGML_RVV) - string(APPEND MARCH_STR "_v") if (GGML_RV_ZVFH) string(APPEND MARCH_STR "_zvfh") endif() @@ -475,12 +478,14 @@ function(ggml_add_cpu_backend_variant_impl tag_name) string(APPEND MARCH_STR "_zvfbfwma") endif() endif() + if (GGML_RV_ZICBOP) string(APPEND MARCH_STR "_zicbop") endif() if (GGML_RV_ZIHINTPAUSE) string(APPEND MARCH_STR "_zihintpause") endif() + list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) else() # Begin with the lowest baseline From 1848f994e324840cd9e1b67f9b2685868546debe Mon Sep 17 00:00:00 2001 From: lhez Date: Thu, 26 Mar 2026 08:52:21 -0700 Subject: [PATCH 052/249] opencl: allow large buffer for adreno (llama/20997) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4dddcd82cfa..c40e1f2d391 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -394,6 +394,9 @@ struct ggml_backend_opencl_context { bool fp16_support; bool has_vector_subgroup_broadcast; bool disable_fusion; + + bool adreno_has_large_buffer; + bool adreno_use_large_buffer; ggml_cl_compiler_version adreno_cl_compiler_version; int adreno_wave_size; @@ -787,6 +790,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -cl-mad-enable -cl-unsafe-math-optimizations" " -cl-finite-math-only -cl-fast-relaxed-math"; + if (backend_ctx->adreno_use_large_buffer) { + compile_opts += " -qcom-enable-large-buffer "; + } + GGML_LOG_INFO("ggml_opencl: loading OpenCL kernels"); // add @@ -3020,6 +3027,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // Check if ext_buffer contains cl_khr_fp16 backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); + // check Adreno large buffer support + backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL; // fp16 is required if (!backend_ctx->fp16_support) { @@ -3086,6 +3095,18 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); #endif // GGML_OPENCL_USE_ADRENO_KERNELS + // determine whether to use large buffer for Adreno + backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr && + backend_ctx->gpu_family == GPU_FAMILY::ADRENO; + if (backend_ctx->adreno_use_large_buffer) { + if (!backend_ctx->adreno_has_large_buffer) { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer requested but not supported by driver, will use regular buffer\n"); + backend_ctx->adreno_use_large_buffer = false; + } else { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer enabled\n"); + } + } + cl_int err; // A local ref of cl_context for convenience @@ -5660,6 +5681,11 @@ static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_b cl_int err; cl_mem mem = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, size, NULL, &err); + if (err != CL_SUCCESS && backend_ctx->adreno_use_large_buffer) { + cl_mem_properties props[] = { 0x41A6 /* CL_LARGE_BUFFER_QCOM */, 1, 0 }; + mem = clCreateBufferWithProperties(backend_ctx->context, props, CL_MEM_READ_WRITE, size, NULL, &err); + } + if (err != CL_SUCCESS) { GGML_LOG_INFO("%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0); return nullptr; From 45a708343104837e1bc74a983b94d1101fc6e13a Mon Sep 17 00:00:00 2001 From: uvos Date: Thu, 26 Mar 2026 23:06:33 +0100 Subject: [PATCH 053/249] hip: use fnuz fp8 for conversion on CDNA3 (llama/21040) --- ggml/src/ggml-cuda/common.cuh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9f93c70d21d..7d7f20af3a0 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -802,7 +802,13 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) { #ifdef FP8_AVAILABLE const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. +#if defined(GGML_USE_HIP) && defined(CDNA3) + // ROCm dose not support fp8 in software on devices with fp8 hardware, + // but CDNA3 supports only e4m3_fnuz (no inf). + const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast(&bits); +#else const __nv_fp8_e4m3 xf = *reinterpret_cast(&bits); +#endif // defined(GGML_USE_HIP) && defined(GGML_USE_HIP) return static_cast(xf) / 2; #else NO_DEVICE_CODE; From b564a99ed63abdf646a30f00b54bd5557238e900 Mon Sep 17 00:00:00 2001 From: ren <189031187+lathrys-at@users.noreply.github.com> Date: Fri, 27 Mar 2026 00:05:21 -0700 Subject: [PATCH 054/249] metal : Fix dimension constraint violation in matmul2d descriptor (llama/21048) Updates Metal tensor API test probe to fix the dimension constraint violation in the matmul2d descriptor (at least one value must be a multiple of 16). --- ggml/src/ggml-metal/ggml-metal-device.m | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index cbef2fb4879..17d51b11b6e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -690,7 +690,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { " auto tB = B.slice((int)tgid.x, 0); \n" " \n" " matmul2d< \n" - " matmul2d_descriptor(8, 8, dynamic_extent), \n" + " matmul2d_descriptor(16, 16, dynamic_extent), \n" " execution_simdgroups<4>> mm; \n" " \n" " auto cT = mm.get_destination_cooperative_tensor(); \n" @@ -740,7 +740,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { " auto tB = B.slice((int)tgid.x, 0); \n" " \n" " matmul2d< \n" - " matmul2d_descriptor(8, 8, dynamic_extent), \n" + " matmul2d_descriptor(16, 16, dynamic_extent), \n" " execution_simdgroups<4>> mm; \n" " \n" " auto cT = mm.get_destination_cooperative_tensor(); \n" From 7f466e237b02974aed2c1c49b13b3c847c1fa55b Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Fri, 27 Mar 2026 10:59:35 +0200 Subject: [PATCH 055/249] rpc : proper handling of data pointers to CPU buffers (llama/21030) The compute graph may contain tensors pointing to CPU buffers. In these cases the buffer address is serialized as 0 and sent over the wire. However, the data pointer is serialized as-is and this prevents proper validation on the server side. This patches fixes this by serializing the data pointer as 0 for non-RPC buffers and doing proper validation on the server side. closes: #21006 --- ggml/src/ggml-rpc/ggml-rpc.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 0ed2c0dce60..16f6abdffd6 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -589,8 +589,10 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { ggml_backend_buffer_t buffer = tensor->buffer; ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; result.buffer = ctx != nullptr ? ctx->remote_ptr : 0; + result.data = reinterpret_cast(tensor->data); } else { result.buffer = 0; + result.data = 0; } for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { result.ne[i] = tensor->ne[i]; @@ -606,7 +608,6 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { } result.view_src = reinterpret_cast(tensor->view_src); result.view_offs = tensor->view_offs; - result.data = reinterpret_cast(tensor->data); // Avoid sending uninitialized data over the wire memset(result.name, 0, sizeof(result.name)); @@ -1443,9 +1444,11 @@ ggml_tensor * rpc_server::create_node(uint64_t id, const rpc_tensor * tensor = it_ptr->second; struct ggml_tensor * result = deserialize_tensor(ctx, tensor); - if (result == nullptr || result->buffer == nullptr) { - GGML_LOG_ERROR("[%s] invalid tensor: null %s (id=%" PRIu64 ")\n", - __func__, result == nullptr ? "tensor" : "buffer", id); + if (result == nullptr) { + return nullptr; + } + if (result->buffer == nullptr && result->data != nullptr) { + GGML_LOG_ERROR("[%s] invalid data ptr", __func__); return nullptr; } tensor_map[id] = result; From 52699f6d193058353e0832caaf8c055c705e72a6 Mon Sep 17 00:00:00 2001 From: Yiwei Shao <44545837+njsyw1997@users.noreply.github.com> Date: Fri, 27 Mar 2026 09:22:41 -0700 Subject: [PATCH 056/249] hexagon: support for IQ4_NL and MXFP4 (llama/21018) * ggml-hexagon: add IQ4_NL and MXFP4 HMX matmul support - Add IQ4_NL quantization type support to Hexagon backend (buffer set/get tensor repack, mul_mat, mul_mat_id dispatch) - Implement HVX IQ4_NL vec_dot kernels (1x1, 2x1, 2x2) with LUT-based 4-bit index to int8 kvalue dequantization - Add MXFP4 HMX dequantization path with E8M0 scale conversion, including batch-4 fast path and single-tile fallback - Unify quantized row size / scale offset logic to handle Q4_0, Q8_0, IQ4_NL, and MXFP4 in the DMA fetch path * ggml-hexagon: fix SKIP_QUANTIZE src1 address mismatch in mixed-quant models * Fix the pragma indent --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 37 +- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 209 +++++++++++- ggml/src/ggml-hexagon/htp/htp-ctx.h | 6 + ggml/src/ggml-hexagon/htp/main.c | 10 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 380 +++++++++++++++++++++ 5 files changed, 619 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 9c1ce93cc69..dd604db4333 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1406,6 +1406,13 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, repack_q8_0_q8x4x2(tensor, data, size); break; + case GGML_TYPE_IQ4_NL: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + // IQ4_NL has identical block layout to Q4_0 (ggml_half d + uint8_t qs[16]) + repack_q4_0_q4x4x2(tensor, data, size); + break; + case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1442,6 +1449,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, repack_q8x4x2_q8_0(data, tensor, size); break; + case GGML_TYPE_IQ4_NL: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4x4x2_q4_0(data, tensor, size); + break; + case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1819,6 +1832,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s switch (src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: if (src0->ne[0] % 32) { return false; @@ -1868,6 +1882,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session switch (src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: if ((src0->ne[0] % 32)) { return false; @@ -2596,8 +2611,26 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) { delete backend; } +// Map weight type to its activation quantization family. +// Types in the same family produce identical Q8 formats in VTCM and can +// safely share quantized activation data via SKIP_QUANTIZE. +// When adding a new quantized type, assign it the correct family here. +static inline int act_quant_family(enum ggml_type wtype) { + switch (wtype) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + return 1; // Q8x4x2 + default: + return 0; // unknown / not quantized + } +} + static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) { - return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type)); + return (op0 && op0->src[1] == op1->src[1] && + act_quant_family(op0->src[0]->type) == act_quant_family(op1->src[0]->type) && + act_quant_family(op0->src[0]->type) != 0); } static inline bool is_compute_op(ggml_tensor *node) @@ -3364,6 +3397,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL, + "please update hexagon_type to match ggml_type"); const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL"); const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index a56356bee9f..4ff2b36de96 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -30,6 +30,12 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, }; +// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value +// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 +static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0, +}; + static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0, 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, @@ -46,7 +52,8 @@ static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned // Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes #define HMX_X4X2_SCALES_PER_BLK 8 -#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes +#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL) +#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4) static inline void swap_ptr(void **p1, void **p2) { void *t = *p1; @@ -78,9 +85,11 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) { switch (weight_type) { case HTP_TYPE_Q4_0: case HTP_TYPE_IQ4_NL: - return (size_t)nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb + return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb case HTP_TYPE_Q8_0: - return (size_t)nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb + return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb + case HTP_TYPE_MXFP4: + return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb default: return 0; } @@ -284,6 +293,87 @@ static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx( return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); } +// --- MXFP4 E8M0 scale conversion and dequantization --- +// +// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack. +// Scalar loads from the stack array execute on the scalar pipeline, in parallel +// with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop. +// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10 +// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15. + +typedef struct { + __fp16 v[8] __attribute__((aligned(16))); +} mxfp4_scales_t; + +static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) { + mxfp4_scales_t s; + HVX_Vector v = hvx_vmemu(e8m0_8); + HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v)); + vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112)); + vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero()); + vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30)); + vh = Q6_Vh_vasl_VhR(vh, 10); + hvx_vec_store_u(s.v, 16, vh); + return s; +} + +static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) { + return hvx_vec_splat_f16(scales.v[idx]); +} + +// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16. +static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32, + bool upper_nibbles, + int sub_blk, + const HVX_Vector vlut_cvt, + mxfp4_scales_t scales) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc)); +} + +// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes). +static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, + bool upper_nibbles, + int sub_blk_base, + const HVX_Vector vlut_cvt, + mxfp4_scales_t scales, + HVX_Vector out[4]) { + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); + HVX_Vector v_hi = Q6_V_hi_W(vp); + + HVX_VectorPred q64 = Q6_Q_vsetq_R(64); + HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0), + mxfp4_extract_splat(scales, sub_blk_base + 1)); + HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2), + mxfp4_extract_splat(scales, sub_blk_base + 3)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); + + out[0] = v_lo; + out[1] = Q6_V_vror_VR(v_lo, 64); + out[2] = v_hi; + out[3] = Q6_V_vror_VR(v_hi, 64); +} + // Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. // Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes. // Output: vtcm_dst in tile-major FP16 layout. @@ -295,11 +385,11 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( int start_tile, int end_tile) { const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; - const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); - const int qrow_size = is_q4 ? (k_block / 2) : k_block; + const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2); - const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) - ? hvx_vmem(iq4_nl_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut); + const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : + (weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) : + hvx_vmem(q4_0_to_fp16_lut); // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128 @@ -312,8 +402,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( int ct = t / n_k_tiles; // column tile index int kt = t % n_k_tiles; // K tile index - // --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row --- - if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + // --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row --- + if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) && + ((t + 3) / n_k_tiles == ct)) { int blk_idx = (kt * 32) / QK_Q4_0x4x2; int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 bool upper = (sub_blk_base >= 4); @@ -351,10 +442,60 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( continue; } + // --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales --- + if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + int blk_idx = (kt * 32) / QK_MXFP4x4x2; + int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4 + bool upper = (sub_blk_base >= 4); + int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes + int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales + + __fp16 * tile_bases[4]; + for (int g = 0; g < 4; g++) { + tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; + } + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + const uint8_t * r0 = vtcm_src + row0 * row_stride; + const uint8_t * r1 = vtcm_src + row1 * row_stride; + + // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) + mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); + + HVX_Vector v0[4], v1[4]; + dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0); + if (row1 < n_cols) { + mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); + dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1); + } else { + v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero(); + } + + for (int g = 0; g < 4; g++) { + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); + } + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + for (int g = 0; g < 4; g++) { + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); + } + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + + for (int g = 0; g < 4; g++) { + (void) *(volatile HVX_Vector *) (tile_bases[g]); + } + + t += 4; + continue; + } + // --- Single-tile fallback --- __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS; - if (is_q4) { + if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) { int blk_idx = (kt * 32) / QK_Q4_0x4x2; int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; bool upper = (sub_blk >= 4); @@ -382,6 +523,39 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } (void) *(volatile HVX_Vector *)(tile_base); + } else if (weight_type == HTP_TYPE_MXFP4) { + int blk_idx = (kt * 32) / QK_MXFP4x4x2; + int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32; + bool upper = (sub_blk >= 4); + int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; + int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t * r0 = vtcm_src + row0 * row_stride; + const uint8_t * r1 = vtcm_src + row1 * row_stride; + + // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) + mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); + + HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8); + HVX_Vector v1; + if (row1 < n_cols) { + mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); + v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8); + } else { + v1 = Q6_V_vzero(); + } + + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *) (tile_base); } else { // Q8_0 int blk_idx = (kt * 32) / QK_Q8_0x4x2; @@ -1455,21 +1629,24 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict { qweight_fetch_task_state_t s; - const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); const int blk_start = kk / QK_Q4_0x4x2; const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - const int full_qrow = is_q4 ? (k / 2) : k; + const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); + const int scale_blk_size = + (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; s.dst = vtcm_scratch0; s.src = w + nc * row_stride; s.n_rows = n_blk_sz; s.src_stride = row_stride; s.dst_stride = sub_row_stride; - s.quant_off = is_q4 ? (blk_start * (QK_Q4_0x4x2 / 2)) : (blk_start * QK_Q8_0x4x2); - s.quant_width = is_q4 ? (nb_sub * (QK_Q4_0x4x2 / 2)) : (nb_sub * QK_Q8_0x4x2); - s.scale_off = full_qrow + blk_start * HMX_X4X2_DBLK_SIZE; - s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE; + s.quant_off = + (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); + s.quant_width = + (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); + s.scale_off = full_qrow + blk_start * scale_blk_size; + s.scale_width = nb_sub * scale_blk_size; // 2D DMA: quants sub-range dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off), diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index a92acfa0a85..6f1917fa2cb 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -31,6 +31,12 @@ struct htp_context { uint32_t opmask; + // Cached src1 spad position from the last quantize pass. + // When SKIP_QUANTIZE is set the Q8 activation data is already in VTCM + // at this address; the matmul must read from here instead of recomputing + // the offset (which depends on the current op's src0 size). + uint8_t * prev_src1_spad; + // HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX) #ifdef HTP_HAS_HMX int hmx_enabled; // Runtime flag: HMX initialisation succeeded diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 70ba9f9f4fe..49f34b5f7d1 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -1114,14 +1114,12 @@ static void proc_hmx_matmul_req(struct htp_context * ctx, return; } - // HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights. - // Other types (e.g. MXFP4) fall back to HVX. + // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. + // Other types fall back to HVX. { uint32_t wtype = req->src0.type; - if (wtype != HTP_TYPE_F16 && - wtype != HTP_TYPE_Q4_0 && - wtype != HTP_TYPE_Q8_0 && - wtype != HTP_TYPE_IQ4_NL) { + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && + wtype != HTP_TYPE_MXFP4) { proc_matmul_req(ctx, req, bufs, n_bufs); return; } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 73aaba79ebf..24b7bad6876 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -60,6 +60,16 @@ static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = { 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20, }; +// IQ4_NL dequantization LUT: maps 4-bit index (0-15) to int8 kvalue +// kvalues: -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 +static const uint8_t __attribute__((aligned(VLEN))) kvalues_iq4nl_lut[] = { + 0x81, 0, 0x98, 0, 0xAD, 0, 0xBF, 0, 0xCF, 0, 0xDD, 0, 0xEA, 0, 0xF6, 0, 0x01, 0, 0x0D, 0, 0x19, 0, 0x26, 0, + 0x35, 0, 0x45, 0, 0x59, 0, 0x71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; + static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0, 0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -68,6 +78,73 @@ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; +static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); + v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); + v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); + v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); + v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); + v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); + r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); + } + + return r; +} + // q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales static inline size_t q8x4x2_row_size(uint32_t ne) { @@ -921,6 +998,293 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } +// ======== IQ4_NL x Q8_0 vec_dot kernels ======== +// Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue). +// Scale format is identical to Q4_0 (fp16 scales). + +static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + + hvx_vec_store_u(s0, 4, r0_sum); +} + +static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, + float * restrict s0, + float * restrict s1, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vy0, + const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; + + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); +} + static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); @@ -2393,6 +2757,12 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; return 0; + case HTP_TYPE_IQ4_NL: + mmctx->type = "iq4nlx4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2; + return 0; case HTP_TYPE_MXFP4: mmctx->type = "mxfp4x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; @@ -2556,6 +2926,13 @@ int op_matmul(struct htp_ops_context * octx) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + // Cache where src1 was written so subsequent SKIP_QUANTIZE ops can find it + octx->ctx->prev_src1_spad = octx->src1_spad.data; + } else { + // SKIP_QUANTIZE: Q8 data lives at the address written by the previous + // quantize pass. The current op may have a different src0 size (e.g. + // IQ4_NL vs MXFP4), so src1_spad.data computed above could be wrong. + octx->src1_spad.data = octx->ctx->prev_src1_spad; } if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { @@ -2659,6 +3036,9 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + octx->ctx->prev_src1_spad = octx->src1_spad.data; + } else { + octx->src1_spad.data = octx->ctx->prev_src1_spad; } if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { From 759f0084b4172f891412fbfd22a1b20a4f25f2c1 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 28 Mar 2026 08:44:56 +0100 Subject: [PATCH 057/249] vulkan: add noncontiguous GLU support (llama/21081) * vulkan: add noncontiguous GLU support * fix compile issue --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 39 +++++++++++++------ .../ggml-vulkan/vulkan-shaders/glu_head.glsl | 10 +++++ .../ggml-vulkan/vulkan-shaders/glu_main.glsl | 22 ++++++++--- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 221e6fa04e9..15ed5b2a79d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1112,6 +1112,16 @@ struct vk_op_glu_push_constants { uint32_t mode; // 0: default, 1: swapped, 2: split float alpha; // for swiglu_oai float limit; + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t ne01; + uint32_t ne02; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t ne11; + uint32_t ne12; }; struct vk_op_unary_push_constants { @@ -5044,7 +5054,7 @@ static vk_device ggml_vk_get_device(size_t idx) { } else { device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); } - vk::DeviceCreateInfo device_create_info; + vk::DeviceCreateInfo device_create_info{}; std::vector device_extensions; vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); @@ -5413,12 +5423,10 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->name = GGML_VK_NAME + std::to_string(idx); - device_create_info = { - vk::DeviceCreateFlags(), - device_queue_create_infos, - {}, - device_extensions - }; + device_create_info + .setFlags(vk::DeviceCreateFlags()) + .setQueueCreateInfos(device_queue_create_infos) + .setPEnabledExtensionNames(device_extensions); device_create_info.setPNext(&device_features2); device->device = device->physical_device.createDevice(device_create_info); @@ -11048,8 +11056,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const const float alpha = op_params_f[2]; const float limit = op_params_f[3]; - GGML_ASSERT(ggml_is_contiguous(src0)); - if (!split) { GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); } else { @@ -11067,7 +11073,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t)dst->ne[0], mode, alpha, - limit + limit, + (uint32_t)(src0->nb[1] / src0->nb[0]), + (uint32_t)(src0->nb[2] / src0->nb[0]), + (uint32_t)(src0->nb[3] / src0->nb[0]), + (uint32_t)src0->ne[1], + (uint32_t)src0->ne[2], + (uint32_t)(dst->nb[1] / dst->nb[0]), + (uint32_t)(dst->nb[2] / dst->nb[0]), + (uint32_t)(dst->nb[3] / dst->nb[0]), + (uint32_t)dst->ne[1], + (uint32_t)dst->ne[2] }); } @@ -15217,8 +15233,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous(op->src[0]) && - (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type); default: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl index 2168989340b..95298922d83 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -16,4 +16,14 @@ layout (push_constant) uniform parameter uint mode; float alpha; float limit; + uint nb01; + uint nb02; + uint nb03; + uint ne01; + uint ne02; + uint nb11; + uint nb12; + uint nb13; + uint ne11; + uint ne12; } p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl index 85cf65a9eca..359461306a5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl @@ -8,22 +8,32 @@ void main() { const uint row = i / p.ne20; const uint col = i - row * p.ne20; + const uint i3 = row / (p.ne01 * p.ne02); + const uint i2 = (row % (p.ne01 * p.ne02)) / p.ne01; + const uint i1 = row % p.ne01; + const uint src_idx = i3 * p.nb03 + i2 * p.nb02 + i1 * p.nb01 + col; + + const uint dst_i3 = row / (p.ne11 * p.ne12); + const uint dst_i2 = (row % (p.ne11 * p.ne12)) / p.ne11; + const uint dst_i1 = row % p.ne11; + const uint dst_idx = dst_i3 * p.nb13 + dst_i2 * p.nb12 + dst_i1 * p.nb11 + col; + if (p.mode == 0) { // Default const uint offset = p.ne00 / 2; - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); } else if (p.mode == 1) { // Swapped const uint offset = p.ne00 / 2; - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); } else { // Split - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); } } From 95ea8f9bfb03a15db08a8989966fd1ae3361e20d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 29 Mar 2026 13:23:24 +0300 Subject: [PATCH 058/249] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 6557fb46cbe..58863dc6bbb 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -c044a8eeae2591faa0950c8b5e514cbc4bbfc4ca +404fcb9d7c96989569e68c9e7881ee3465a05c50 From 166c20b473d5f4d04052e699f992f625ea2a2fdd Mon Sep 17 00:00:00 2001 From: Daniel Worthington-Bodart Date: Fri, 17 Apr 2026 12:36:27 +0100 Subject: [PATCH 059/249] whisper : add stateless VAD detect + explicit state reset for streaming (#3677) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit whisper_vad_detect_speech resets LSTM state on every call, which is correct for batch processing but prevents temporal continuity when calling per-chunk in a streaming loop. Add whisper_vad_detect_speech_no_reset (skips buffer clear) and whisper_vad_reset_state (explicit clear between utterances). Existing whisper_vad_detect_speech is now a thin wrapper — zero behavior change for current callers. Co-authored-by: Claude Opus 4.6 (1M context) --- include/whisper.h | 10 ++++++++++ src/whisper.cpp | 17 +++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/include/whisper.h b/include/whisper.h index f4cc6bf7abd..b5dcdb2917a 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -695,6 +695,16 @@ extern "C" { const float * samples, int n_samples); + // Like whisper_vad_detect_speech, but does not reset LSTM state. + // Use for streaming: call whisper_vad_reset_state() between utterances. + WHISPER_API bool whisper_vad_detect_speech_no_reset( + struct whisper_vad_context * vctx, + const float * samples, + int n_samples); + + // Reset LSTM hidden/cell states to zero. + WHISPER_API void whisper_vad_reset_state(struct whisper_vad_context * vctx); + WHISPER_API int whisper_vad_n_probs(struct whisper_vad_context * vctx); WHISPER_API float * whisper_vad_probs (struct whisper_vad_context * vctx); diff --git a/src/whisper.cpp b/src/whisper.cpp index 86bfafeaad8..2f356da0f06 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -5083,7 +5083,11 @@ struct whisper_vad_context * whisper_vad_init_with_params( return vctx; } -bool whisper_vad_detect_speech( +void whisper_vad_reset_state(whisper_vad_context * vctx) { + ggml_backend_buffer_clear(vctx->buffer, 0); +} + +bool whisper_vad_detect_speech_no_reset( struct whisper_vad_context * vctx, const float * samples, int n_samples) { @@ -5095,9 +5099,6 @@ bool whisper_vad_detect_speech( WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples); WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks); - // Reset LSTM hidden/cell states - ggml_backend_buffer_clear(vctx->buffer, 0); - vctx->probs.resize(n_chunks); WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks); @@ -5165,6 +5166,14 @@ bool whisper_vad_detect_speech( return true; } +bool whisper_vad_detect_speech( + struct whisper_vad_context * vctx, + const float * samples, + int n_samples) { + whisper_vad_reset_state(vctx); + return whisper_vad_detect_speech_no_reset(vctx, samples, n_samples); +} + int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) { return segments->data.size(); } From fc674574ca27cac59a15e5b22a09b9d9ad62aafe Mon Sep 17 00:00:00 2001 From: jinweihan Date: Sun, 19 Apr 2026 22:12:57 -0700 Subject: [PATCH 060/249] bench : sync submit-results URL to ggml-org (#3769) The project moved from ggerganov/ to ggml-org/ and the README already references the new URL in both places it mentions issue #89 (README.md and examples/bench/README.md). Syncing the two remaining hardcoded URLs in examples/bench/bench.cpp and examples/bench.wasm/emscripten.cpp. The old URL still redirects, so this is cosmetic. --- examples/bench.wasm/emscripten.cpp | 2 +- examples/bench/bench.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/bench.wasm/emscripten.cpp b/examples/bench.wasm/emscripten.cpp index 083397db057..7e9f277f66e 100644 --- a/examples/bench.wasm/emscripten.cpp +++ b/examples/bench.wasm/emscripten.cpp @@ -45,7 +45,7 @@ void bench_main(size_t index) { fprintf(stderr, "\n"); fprintf(stderr, "If you wish, you can submit these results here:\n"); fprintf(stderr, "\n"); - fprintf(stderr, " https://github.com/ggerganov/whisper.cpp/issues/89\n"); + fprintf(stderr, " https://github.com/ggml-org/whisper.cpp/issues/89\n"); fprintf(stderr, "\n"); fprintf(stderr, "Please include the following information:\n"); fprintf(stderr, "\n"); diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index 049473d4f32..84915c56a8a 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -157,7 +157,7 @@ static int whisper_bench_full(const whisper_params & params) { fprintf(stderr, "\n"); fprintf(stderr, "If you wish, you can submit these results here:\n"); fprintf(stderr, "\n"); - fprintf(stderr, " https://github.com/ggerganov/whisper.cpp/issues/89\n"); + fprintf(stderr, " https://github.com/ggml-org/whisper.cpp/issues/89\n"); fprintf(stderr, "\n"); fprintf(stderr, "Please include the following information:\n"); fprintf(stderr, "\n"); From 763a4540521ae191c68e79397506b01e3d9c9d78 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 30 Mar 2026 18:34:29 +0300 Subject: [PATCH 061/249] ggml : bump version to 0.9.9 (ggml/1449) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index c780077acaa..a739cca4218 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 8) +set(GGML_VERSION_PATCH 9) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) From 9e96d390f7dc63544ebbdafe36902879c217104a Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Sun, 29 Mar 2026 06:40:13 -0700 Subject: [PATCH 062/249] hexagon: dma optimizations (mostly fixing regressions) (llama/21137) * hex-fa: add simple dma cache for Mask I noticed that we were refetch the mask rows over and over. This simple cache avoids that. * hex-dma: unset in-order desc bit which caused signficant perf regression We don't rely on true in order processing of the DMA descriptors anywhere. Turns out this mode caused significant regression of around 3-4 TPS during token gen. * hex-rope: update comment to clarify that we don't need in-order DMA completions --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 12 ++-- ggml/src/ggml-hexagon/htp/hex-dma.h | 75 ++++++++++++++++++---- ggml/src/ggml-hexagon/htp/rope-ops.c | 4 +- 3 files changed, 74 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 6dc978dd68a..0c9bc785620 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -346,6 +346,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap); + dma_cache m_cache; + dma_cache_init(&m_cache, spad_m, factx->size_m_block, DMA_CACHE_MAX_SIZE); + for (uint32_t ir = ir0; ir < ir1; ++ir) { const uint32_t iq3 = fastdiv(ir, &factx->src0_div21); const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1); @@ -389,9 +392,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); - uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block; // Mask is 1D contiguous for this row - dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); + dma_cache_push(dma, &m_cache, m_src, current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); } // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", @@ -554,7 +556,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); - dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); + dma_cache_push(dma, &m_cache, m_src, next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); } // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", @@ -684,7 +686,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { octx->src0_spad.size_per_thread = size_q_block * 1; octx->src1_spad.size_per_thread = factx.size_k_block * 2; octx->src2_spad.size_per_thread = factx.size_v_block * 2; - octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0; + octx->src3_spad.size_per_thread = mask ? factx.size_m_block * DMA_CACHE_MAX_SIZE : 0; octx->dst_spad.size_per_thread = size_vkq_acc; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; @@ -705,6 +707,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; + // FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread); + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads); } diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index ff166cbcc7a..7685473f463 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -143,7 +143,7 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t desc->desc_size = 0; // 1D mode desc->src_bypass = dma_src_l2_bypass_on; desc->dst_bypass = dma_dst_l2_bypass_on; - desc->order = 1; + desc->order = 0; desc->done = 0; desc->src = (void *) dptr.src; desc->dst = (void *) dptr.dst; @@ -151,8 +151,12 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t q->dptr[q->push_idx] = dptr; - dmlink(q->tail, desc); - q->tail = (dma_descriptor_2d *) desc; + if (size) { + dmlink(q->tail, desc); + q->tail = (dma_descriptor_2d *) desc; + } else { + desc->done = 1; + } // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); q->push_idx = (q->push_idx + 1) & q->idx_mask; @@ -175,7 +179,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t desc->dst_bypass = dma_dst_l2_bypass_on; desc->src_comp = 0; desc->dst_comp = 0; - desc->order = 1; + desc->order = 0; desc->done = 0; desc->src_stride = src_stride; desc->dst_stride = dst_stride; @@ -197,8 +201,12 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t q->dptr[q->push_idx] = dptr; - dmlink(q->tail, desc); - q->tail = desc; + if (nrows) { + dmlink(q->tail, desc); + q->tail = desc; + } else { + desc->done = 1; + } // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); q->push_idx = (q->push_idx + 1) & q->idx_mask; @@ -215,12 +223,9 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) { dma_descriptor_2d * desc = &q->desc[q->pop_idx]; // Wait for desc to complete - while (1) { - dmpoll(); - if (desc->done) { - break; - } + while (!desc->done) { // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx); + dmpoll(); } dptr = q->dptr[q->pop_idx]; @@ -312,6 +317,54 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_ return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows); } +#define DMA_CACHE_MAX_SIZE 64U + +typedef struct { + uint8_t *base; + uint32_t line_size; + uint32_t capacity; + uint32_t src[DMA_CACHE_MAX_SIZE]; + uint16_t age[DMA_CACHE_MAX_SIZE]; +} dma_cache; + +static inline void dma_cache_init(dma_cache *c, uint8_t *base, uint32_t line_size, uint32_t capacity) +{ + c->capacity = (capacity > DMA_CACHE_MAX_SIZE) ? DMA_CACHE_MAX_SIZE : capacity; + c->base = base; + c->line_size = line_size; + + for (unsigned i=0; i < c->capacity; i++) { + c->src[i] = 0; + c->age[i] = 0; + } +} + +static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * src, uint32_t dst_stride, uint32_t src_stride, uint32_t row_size, uint32_t nrows) +{ + uint32_t o_idx = 0; + uint16_t o_age = 0; + uint8_t * dst = 0; + + for (unsigned i=0; i < c->capacity; i++) { + if (c->src[i] == (uint32_t) src) { + c->age[i] = 0; + dst = c->base + (i * c->line_size); nrows = 0; // dummy dma + // FARF(ERROR, "dma-cache: found %p", src); + } else { + c->age[i]++; + if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; } + } + } + if (!dst) { + // FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src); + c->age[o_idx] = 0; + c->src[o_idx] = (uint32_t) src; + dst = c->base + o_idx * c->line_size; // normal nrows dma + } + + return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows); +} + #ifdef __cplusplus } // extern "C" #endif diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index be9469538f6..ecedadb0fea 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -333,8 +333,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); } - // Skip DMA transactions from prev block (if any) - // No need to wait for these since the DMA is setup for in-order processing + // Skip output DMA transactions from prev block (if any) + // No need to wait for those here since we're explicitly waiting for the latest prefecthes below. for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); } // Compute loop From 6b67c918797be49d4c9e67eda05efd490f8e123d Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sun, 29 Mar 2026 22:05:18 +0530 Subject: [PATCH 063/249] Optimize MOE GEMV kernel for BS > 1. (llama/20905) * Optimize MOE GEMV kernel for BS > 1. The previous MOE kernel for BS > 1 had too many thread blocks (nrows_x, nchannels_dst, ncols_dst), with very little work per block. block of (32, 4) was doing inner dot product for a single row. New mul_mat_vec_q_moe kernel is dedicated for MoE multi-token kernel with grid (ceil(nrows_x/rpb), nchannels_dst), block (warp_size, ncols_dst). Each warp handles two rows independently with warp-level reduction only (no shared memory sync). This change doesn't increase any compilation time as a single template instance is needed per type. This also simplifies the original GEMV kernel and gets rid of `is_multi_token_id` specialization. * Remove em-dashes * Cherry-pick changes from @am17an PR https://github.com/ggml-org/llama.cpp/pull/20885 to enable small_k optimization only for cases where it benefits Increase max batch size for MMVQ kernels for MUL_MAT_ID to 8 * Make the max batch size for MOE GEMV kernel configurable based on GPU arch and datatype --------- Co-authored-by: Aman Gupta --- ggml/src/ggml-cuda/ggml-cuda.cu | 19 +- ggml/src/ggml-cuda/mmvq.cu | 393 +++++++++++++++++++++++++++----- ggml/src/ggml-cuda/mmvq.cuh | 5 +- 3 files changed, 358 insertions(+), 59 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index cc80eb3ffc2..d1239b1c5f7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2343,7 +2343,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); if (ne2 <= MMVQ_MAX_BATCH_SIZE) { if (ggml_is_quantized(src0->type)) { - if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) { + const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc); + if (ne2 <= mmvq_mmid_max) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); return; } @@ -2946,14 +2947,18 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { } // [TAG_MUL_MAT_ID_CUDA_GRAPHS] - if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) { - // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs - // TODO: figure out a way to enable for larger batch sizes, without hurting performance - // ref: https://github.com/ggml-org/llama.cpp/pull/18958 - use_cuda_graph = false; + if (node->op == GGML_OP_MUL_MAT_ID) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc); + if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) { + // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs + // TODO: figure out a way to enable for larger batch sizes, without hurting performance + // ref: https://github.com/ggml-org/llama.cpp/pull/18958 + use_cuda_graph = false; #ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); #endif + } } if (!use_cuda_graph) { diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 66bd8beeae7..8d80d1dd9a7 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -97,6 +97,194 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { return MMVQ_PARAMETERS_GENERIC; } +// Per-architecture maximum batch size for which MMVQ should be used for MUL_MAT_ID. +// Returns a value <= MMVQ_MAX_BATCH_SIZE. Default is MMVQ_MAX_BATCH_SIZE. +// Check https://github.com/ggml-org/llama.cpp/pull/20905#issuecomment-4145835627 for details + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 6; + case GGML_TYPE_IQ1_M: return 6; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 5; + case GGML_TYPE_IQ2_XXS: return 5; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 5; + case GGML_TYPE_MXFP4: return 4; + case GGML_TYPE_Q2_K: return 4; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 6; + case GGML_TYPE_Q4_1: return 6; + case GGML_TYPE_Q4_K: return 5; + case GGML_TYPE_Q5_0: return 6; + case GGML_TYPE_Q5_1: return 6; + case GGML_TYPE_Q5_K: return 5; + case GGML_TYPE_Q6_K: return 4; + case GGML_TYPE_Q8_0: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 7; + case GGML_TYPE_IQ3_S: return 6; + case GGML_TYPE_IQ3_XXS: return 7; + case GGML_TYPE_MXFP4: return 7; + case GGML_TYPE_Q2_K: return 7; + case GGML_TYPE_Q3_K: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_gcn(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 5; + case GGML_TYPE_IQ1_M: return 5; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 4; + case GGML_TYPE_Q2_K: return 4; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 5; + case GGML_TYPE_Q4_1: return 5; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_K: return 4; + case GGML_TYPE_Q6_K: return 4; + case GGML_TYPE_Q8_0: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_cdna(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 5; + case GGML_TYPE_IQ2_XS: return 5; + case GGML_TYPE_IQ2_XXS: return 5; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna1_rdna2(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_Q2_K: return 7; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_K: return 5; + case GGML_TYPE_Q5_K: return 6; + case GGML_TYPE_Q6_K: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna3(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 6; + case GGML_TYPE_IQ1_M: return 6; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 6; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_K: return 4; + case GGML_TYPE_Q6_K: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 7; + case GGML_TYPE_IQ1_M: return 7; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 7; + case GGML_TYPE_IQ4_XS: return 5; + case GGML_TYPE_MXFP4: return 5; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 7; + case GGML_TYPE_Q4_1: return 7; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_0: return 7; + case GGML_TYPE_Q5_1: return 7; + case GGML_TYPE_Q5_K: return 5; + case GGML_TYPE_Q6_K: return 5; + case GGML_TYPE_Q8_0: return 7; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +// Host function: returns the max batch size for the current arch+type at runtime. +int get_mmvq_mmid_max_batch(ggml_type type, int cc) { + // NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID. + if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return MMVQ_MAX_BATCH_SIZE; + } + if (cc >= GGML_CUDA_CC_TURING) { + return get_mmvq_mmid_max_batch_turing_plus(type); + } + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + return get_mmvq_mmid_max_batch_pascal_older(type); + } + // AMD + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return get_mmvq_mmid_max_batch_rdna4(type); + } + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + return get_mmvq_mmid_max_batch_rdna3(type); + } + if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) { + return get_mmvq_mmid_max_batch_rdna1_rdna2(type); + } + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return get_mmvq_mmid_max_batch_cdna(type); + } + if (GGML_CUDA_CC_IS_GCN(cc)) { + return get_mmvq_mmid_max_batch_gcn(type); + } + return MMVQ_MAX_BATCH_SIZE; +} + +// Device constexpr: returns the max batch size for the current arch+type at compile time. +template +static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() { +#if defined(RDNA4) + return get_mmvq_mmid_max_batch_rdna4(type); +#elif defined(RDNA3) + return get_mmvq_mmid_max_batch_rdna3(type); +#elif defined(RDNA2) || defined(RDNA1) + return get_mmvq_mmid_max_batch_rdna1_rdna2(type); +#elif defined(CDNA) + return get_mmvq_mmid_max_batch_cdna(type); +#elif defined(GCN) + return get_mmvq_mmid_max_batch_gcn(type); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ >= GGML_CUDA_CC_ADA_LOVELACE) + return MMVQ_MAX_BATCH_SIZE; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + return get_mmvq_mmid_max_batch_turing_plus(type); +#else + return get_mmvq_mmid_max_batch_pascal_older(type); +#endif +} + static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC) { switch (ncols_dst) { @@ -195,7 +383,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -template +template __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, @@ -222,22 +410,13 @@ static __global__ void mul_mat_vec_q( const uint32_t channel_dst = blockIdx.y; - uint32_t token_idx = 0; uint32_t channel_x; uint32_t channel_y; uint32_t sample_dst; - if constexpr (is_multi_token_id) { - // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case - token_idx = blockIdx.z; - channel_x = ids[channel_dst + token_idx * ids_stride]; - channel_y = fastmodulo(channel_dst, nchannels_y); - sample_dst = 0; - } else { - channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); - channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; - sample_dst = blockIdx.z; - } + channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + sample_dst = blockIdx.z; const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_y = sample_dst; @@ -294,9 +473,6 @@ static __global__ void mul_mat_vec_q( float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; - if constexpr (is_multi_token_id) { - y += token_idx*stride_col_y; - } const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { @@ -350,10 +526,6 @@ static __global__ void mul_mat_vec_q( dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; - if constexpr (is_multi_token_id) { - dst += token_idx*stride_col_dst; - } - // sum up partial sums and write back result #pragma unroll for (int j = 0; j < ncols_dst; ++j) { @@ -413,6 +585,69 @@ static __global__ void mul_mat_vec_q( } } +// Dedicated MoE multi-token kernel. +// Grid: (ceil(nrows_x / c_rows_per_block), nchannels_dst) +// Block: (warp_size, ncols_dst) - each warp handles one token independently. +// No shared memory reduction needed since each warp works alone. +template +__launch_bounds__(get_mmvq_mmid_max_batch_for_device()*ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mul_mat_vec_q_moe( + const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, + float * __restrict__ dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, + const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, + const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, + const uint32_t ncols_dst, const uint32_t ids_stride) { + + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); + + const uint32_t token_idx = threadIdx.y; + const int row0 = c_rows_per_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + constexpr int blocks_per_iter = vdr * warp_size / qi; + + const uint32_t channel_dst = blockIdx.y; + + if (token_idx >= ncols_dst) { + return; + } + + const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride]; + const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y); + + const block_q8_1 * y = ((const block_q8_1 *) vy) + channel_y*stride_channel_y + token_idx*stride_col_y; + const int kbx_offset = channel_x*stride_channel_x + row0*stride_row_x; + + // partial sum for each thread + float tmp[c_rows_per_block] = {0.0f}; + + for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); + const int kqs = vdr * (threadIdx.x % (qi/vdr)); + +#pragma unroll + for (int i = 0; i < c_rows_per_block; ++i) { + tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs); + } + } + + // Warp-level reduction only - no shared memory needed +#pragma unroll + for (int i = 0; i < c_rows_per_block; ++i) { + tmp[i] = warp_reduce_sum(tmp[i]); + } + + // Write results + if (threadIdx.x < c_rows_per_block && (c_rows_per_block == 1 || uint32_t(row0 + threadIdx.x) < nrows_x)) { + dst[channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0 + threadIdx.x] = tmp[threadIdx.x]; + } +} + template static std::pair calc_launch_params( const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, @@ -425,7 +660,7 @@ static std::pair calc_launch_params( return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, @@ -438,7 +673,7 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -448,12 +683,33 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } +template +static void mul_mat_vec_q_moe_launch( + const void * vx, const void * vy, const int32_t * ids, float * dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, + const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, + const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, + const uint32_t ncols_dst, const uint32_t ids_stride, + const int warp_size, const int nchannels_dst, cudaStream_t stream) { + + constexpr int rows_per_block = 2; // 2 gives best perf based on tuning + const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block; + const dim3 block_nums(nblocks_rows, nchannels_dst); + const dim3 block_dims(warp_size, ncols_dst); + + mul_mat_vec_q_moe<<>>( + vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x, + stride_row_x, stride_col_y, stride_col_dst, + stride_channel_x, stride_channel_y, stride_channel_dst, + ncols_dst, ids_stride); +} + template static void mul_mat_vec_q_switch_ncols_dst( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, @@ -472,20 +728,62 @@ static void mul_mat_vec_q_switch_ncols_dst( const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const int device = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[device].cc; const int warp_size = ggml_cuda_info().devices[device].warp_size; - const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); + const mmvq_parameter_table_id table_id = get_device_table_id(cc); const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const bool has_ids = ids != nullptr; + const auto should_use_small_k = [&](int c_ncols_dst) { + // When K is small, increase rows_per_block to match nwarps so each warp has more work to do + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + + constexpr std::array iq_slow_turing = { + GGML_TYPE_IQ3_XXS, + GGML_TYPE_IQ3_S, + }; + constexpr std::array iq_slow_other = { + GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, + GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, + }; + constexpr std::array slow_pascal = { + GGML_TYPE_IQ3_S, + GGML_TYPE_Q2_K, + GGML_TYPE_Q3_K, + }; + + const bool is_nvidia_turing_plus = GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_TURING; + const bool is_nvidia_pascal_older = GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA; + + if (is_nvidia_turing_plus) { + if (ncols_dst == 1 && + std::find(iq_slow_turing.begin(), iq_slow_turing.end(), type) != iq_slow_turing.end()) { + use = false; + } + } else if ((ncols_dst == 1 && std::find(iq_slow_other.begin(), iq_slow_other.end(), type) != iq_slow_other.end()) || + (is_nvidia_pascal_older && std::find(slow_pascal.begin(), slow_pascal.end(), type) != slow_pascal.end()) || + GGML_CUDA_CC_IS_RDNA(cc)) { + use = false; + } + + return use; + }; + if (has_ids && ncols_dst > 1) { - // Multi-token MUL_MAT_ID path only - single-token goes through regular path below - constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + // Multi-token MUL_MAT_ID path - dedicated MoE kernel + mul_mat_vec_q_moe_launch( + vx, vy, ids, dst, ncols_x, nchannels_y_fd, nrows_x, + stride_row_x, stride_col_y, stride_col_dst, + stride_channel_x, stride_channel_y, stride_channel_dst, + ncols_dst, ids_stride, warp_size, nchannels_dst, stream); return; } @@ -493,31 +791,24 @@ static void mul_mat_vec_q_switch_ncols_dst( case 1: { constexpr int c_ncols_dst = 1; - // When K is small, increase rows_per_block to match nwarps so each warp has more work to do - // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qi = ggml_cuda_type_traits::qi; - constexpr int vdr = get_vdr_mmvq(type); - const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_iter_1warp = vdr * warp_size / qi; - const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); - const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + bool use_small_k = should_use_small_k(c_ncols_dst); + if (use_small_k) { - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, - warp_size, table_id, true); - mul_mat_vec_q_switch_fusion( + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, + nsamples_dst, warp_size, table_id, true); + mul_mat_vec_q_switch_fusion( vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, + stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride, + stream); } else { - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, - warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, + nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion( vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, + stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride, + stream); } } break; case 2: { diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 8a154631f69..6bf0a8e8677 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -1,7 +1,10 @@ #include "common.cuh" #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. -#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID + +// Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID, +// based on the quantization type and GPU architecture (compute capability). +int get_mmvq_mmid_max_batch(ggml_type type, int cc); void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); From 40ddc5a5b911851f4867207c80d8db2eb238388f Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Mon, 30 Mar 2026 17:05:11 +0300 Subject: [PATCH 064/249] rpc : fix misleading error log (llama/21184) When RPC is running with a remote backend which doesn't have init_tensor function (like CPU and Metal), the server log gets full with error messages saying that init_tensor is being called with null buffer which is incorrect. This patch fixes this. --- ggml/src/ggml-rpc/ggml-rpc.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 16f6abdffd6..1378ba9f5bf 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1340,7 +1340,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { if (buffer && buffer->iface.init_tensor) { buffer->iface.init_tensor(buffer, tensor); } else { - GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n"); + if (!buffer) { + GGML_LOG_ERROR("Tensor with null buffer passed to init_tensor function\n"); + } } if (tensor->extra != nullptr) { From 75b9543856158561584afe59772713ded1e82e95 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Mon, 30 Mar 2026 16:20:00 +0200 Subject: [PATCH 065/249] CUDA : Fix CUB's argsort when nrows % block_size == 0 CCCL < 3.1 (llama/21181) * CUDA: Fix CUB's argsort when nrows % block_size == 0 CCCL < 3.1 We wrongly calculated offset_grid as `ceildiv(nrows, block_size)`, while it must be `ceildiv(nrows + 1, block_size)`. As a consequence, we had uninitialized values in `offset_iterator[nrows]` for the case when `nrows % block_size == 0`. Fixes #21162 * Reduce nrows in test case to 256, don't need 768 --- ggml/src/ggml-cuda/argsort.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 4896669c32a..38fdf3678c1 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -47,9 +47,11 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, #ifdef STRIDED_ITERATOR_AVAILABLE auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols); #else - ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); + // offset_iterator needs to populate nrows + 1 elements, so we also have to ceildiv nrows + 1 by block_size + const int nrows_offset = nrows + 1; + ggml_cuda_pool_alloc offsets_alloc(pool, nrows_offset); int * offset_iterator = offsets_alloc.get(); - const dim3 offset_grid((nrows + block_size - 1) / block_size); + const dim3 offset_grid((nrows_offset + block_size - 1) / block_size); init_offsets<<>>(offset_iterator, ncols, nrows); #endif CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream)); From 6ac5a50005e7080d1c1b293c8e753ea135a9f325 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Mon, 30 Mar 2026 12:19:16 -0700 Subject: [PATCH 066/249] opencl: add q4_K gemm and gemv kernels for Adreno (llama/20919) * opencl: add q4_K gemm and gemv kernels for Adreno * opencl: fix whitespace * opencl: add workarounds for compiler bugs on older devices * opencl: handle fp16 denorm on X Elite * opencl: fix kernel build error * opencl: fix whitespace * opencl: make q4_K cvt kernels signature consistent --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 312 +++++++++++++++++ ggml/src/ggml-opencl/kernels/cvt.cl | 75 ++++- .../kernels/gemm_noshuffle_q4_k_f32.cl | 172 ++++++++++ .../kernels/gemv_noshuffle_q4_k_f32.cl | 318 ++++++++++++++++++ 5 files changed, 877 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index af29f3b8f4c..540942b195d 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -114,6 +114,8 @@ set(GGML_OPENCL_KERNELS gemv_noshuffle_q4_1_f32 gemm_noshuffle_q4_1_f32 gemv_noshuffle_general_q8_0_f32 + gemv_noshuffle_q4_k_f32 + gemm_noshuffle_q4_k_f32 gemv_noshuffle_q6_k_f32 gemm_noshuffle_q6_k_f32 mul diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index c40e1f2d391..0f6628c377d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -538,6 +538,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q4_1_noshuffle; cl_kernel kernel_restore_block_q4_1_noshuffle; + cl_kernel kernel_convert_block_q4_K_noshuffle; + cl_kernel kernel_restore_block_q4_K_noshuffle; cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; @@ -720,6 +722,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemm_noshuffle_q4_1_f32; cl_kernel kernel_mul_mm_q8_0_f32_8x4; cl_kernel CL_mul_mat_vec_q8_0_f32; + cl_kernel kernel_gemv_noshuffle_q4_k_f32; + cl_kernel kernel_gemm_noshuffle_q4_k_f32; cl_kernel kernel_gemv_noshuffle_q6_K_f32; cl_kernel kernel_gemm_noshuffle_q6_K_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS @@ -932,6 +936,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err)); @@ -2619,6 +2625,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemm_noshuffle_q4_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q4_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q4_k_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_k_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q4_k_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q4_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q4_k_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_k_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std + " -cl-mad-enable " " -cl-fast-relaxed-math"; @@ -5060,12 +5105,25 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; + } + #else + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; + #endif + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {64, 1, 1}; @@ -5076,6 +5134,20 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); tensor->extra = extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + // Transpose q, d, dm as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/256, M); + transpose_2d_as_16b(backend_ctx, extra->dm, extra->dm, size_dm, K/256, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } if (tensor->type == GGML_TYPE_Q6_K) { @@ -5516,12 +5588,60 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_nbytes(tensor), NULL, &err); CL_CHECK(err); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_trans_dm; + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_trans_dm.allocate(backend_ctx->context, size_dm); + + // Transpose q, d, dm back + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256); + transpose_2d_as_16b(backend_ctx, extra->dm, buf_trans_dm.buffer, size_dm, M, K/256); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_dm.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {1, 1, 1}; @@ -9688,6 +9808,192 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t #endif } +static void ggml_cl_mul_mat_q4_k_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_k = (ggml_tensor_extra_cl_q4_K *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + cl_uchar mask_d6 = 0x3F; + cl_uchar mask_d4 = 0x0F; + cl_uchar mask_hi2 = 0xC0; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q4_k->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q4_k_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_k->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_k->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_hi2)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q4_k_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_k->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_k->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_k->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_uchar), &mask_hi2)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat_q6_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); @@ -10014,6 +10320,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // q4_k x fp32 + if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst); + return; + } + // q6_K x fp32 if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 34930dfbe6a..81fe17fa10f 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -424,13 +424,17 @@ kernel void kernel_restore_block_q8_0_trans( // Convert the block_q4_K format to 4 separate arrays (AOS -> SOA). // This kernel does not deshuffle the bits. // Each thread processes a super block. +// Mask args are just to keep the signature consistent with the no-shuffle +// version and they are not used in this kernel. //------------------------------------------------------------------------------ kernel void kernel_convert_block_q4_K( global struct block_q4_K * src0, global uchar * dst_q, global uchar * dst_s, global half * dst_d, - global half * dst_dm + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 ) { global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0); global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0); @@ -451,12 +455,15 @@ kernel void kernel_convert_block_q4_K( // Restore block_q4_K from flattened arrays. // Each thread processes a super block. +// Mask args are just to keep the signature consistent with the no-shuffle ones. kernel void kernel_restore_block_q4_K( global uchar * src_q, global uchar * src_s, global half * src_d, global half * src_dm, - global struct block_q4_K * dst + global struct block_q4_K * dst, + uchar mask_0F, + uchar mask_F0 ) { global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0); global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0); @@ -475,6 +482,70 @@ kernel void kernel_restore_block_q4_K( } } +kernel void kernel_convert_block_q4_K_noshuffle( + global struct block_q4_K * src0, + global uchar * dst_q, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2 * get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->q[i*32 + 2*j]; + uchar x1 = b->q[i*32 + 2*j + 1]; + q[i*32 + j] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + q[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q4_K_noshuffle( + global uchar * src_q, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q4_K * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2 * get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = q[i*32 + j]; + uchar hi = q[i*32 + j + 16]; + b->q[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->q[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q6_K // Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl new file mode 100644 index 00000000000..99fd1fd7bf1 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl @@ -0,0 +1,172 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif +kernel void kernel_gemm_noshuffle_q4_k_f32( + global const ushort * src0_q, + global const uchar * src0_s, + global const half * src0_d, + global const half * src0_dm, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + dst = (global float *)((global char *)dst + offsetd); + int n_4 = n >> 2; + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + int num_blocks_K = k / QK_K; + + global const ushort * weight_ptr = src0_q + gx_2; + global const half * d_ptr = src0_d + gx_2; + global const half * dm_ptr = src0_dm + gx_2; + + for (int i = 0; i < k; i += 32) { + int sb_idx = i / QK_K; + int sub_idx = (i / 32) % 8; + + half4 d = vload4(0, d_ptr + sb_idx * m); + half4 dm = vload4(0, dm_ptr + sb_idx * m); + + global const uchar * sc0 = src0_s + (gx_2+0) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc1 = src0_s + (gx_2+1) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc2 = src0_s + (gx_2+2) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc3 = src0_s + (gx_2+3) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + + uchar sv0, mn0, sv1, mn1, sv2, mn2, sv3, mn3; + get_scale_min_k4(sub_idx, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc2, &sv2, &mn2, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc3, &sv3, &mn3, mask_d6, mask_d4, mask_hi2); + + half4 scale = convert_half4(convert_float4(d) * convert_float4((uchar4)(sv0, sv1, sv2, sv3))); + half4 mval = convert_half4(convert_float4(dm) * convert_float4((uchar4)(mn0, mn1, mn2, mn3))); + + for (int l = 0; l < 32; l += 4) { + int ki = i + l; + ushort4 bits4 = vload4(0, weight_ptr + (ki/4) * m); + + // j=0 + B.s0123 = read_imageh(src1, gy*2 + (ki+0) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+0) * n_4); + dequantized_weights.s0 = (bits4.s0 & 0x000F) * scale.s0 - mval.s0; + dequantized_weights.s1 = (bits4.s1 & 0x000F) * scale.s1 - mval.s1; + dequantized_weights.s2 = (bits4.s2 & 0x000F) * scale.s2 - mval.s2; + dequantized_weights.s3 = (bits4.s3 & 0x000F) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (ki+1) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+1) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0x00F0) >> 4) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0x00F0) >> 4) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0x00F0) >> 4) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0x00F0) >> 4) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (ki+2) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+2) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0x0F00) >> 8) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0x0F00) >> 8) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0x0F00) >> 8) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0x0F00) >> 8) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (ki+3) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+3) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0xF000) >> 12) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0xF000) >> 12) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0xF000) >> 12) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0xF000) >> 12) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + } + + int idx = (gy<<3)*m + (gx<<2); + + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl new file mode 100644 index 00000000000..dd1e2b55c0b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl @@ -0,0 +1,318 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK_K 256 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q4_k_f32( + read_only image1d_buffer_t src0_q, + global half2 * src0_d, + global half2 * src0_m, + global uchar * src0_s, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + uint scales_per_row = (K / QK_K) * 12; + + private uint4 regA; + private half2 regS; + private half2 regM; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / 32); k += NSUBGROUPS) { + uint sb = k / 8; + uint j = k % 8; + + half2 d = src0_d[gid + sb * LINE_STRIDE_A]; + half2 dm = src0_m[gid + sb * LINE_STRIDE_A]; + + global const uchar * sc0 = src0_s + 2 * gid * scales_per_row + sb * 12; + global const uchar * sc1 = src0_s + (2 * gid + 1) * scales_per_row + sb * 12; + + uchar sv0, mn0, sv1, mn1; + get_scale_min_k4(j, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(j, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + + regS = convert_half2(convert_float2(d) * convert_float2((uchar2)(sv0, sv1))); + regM = convert_half2(convert_float2(dm) * convert_float2((uchar2)(mn0, mn1))); + + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} From 952c66237de555d87a1ae3f39948fe6a1b6cdfb5 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Tue, 31 Mar 2026 18:31:50 +0800 Subject: [PATCH 067/249] sycl : enhance fattn perf (llama/21185) --- ggml/src/ggml-sycl/fattn-tile.hpp | 83 ++++++++++++++++--------------- 1 file changed, 43 insertions(+), 40 deletions(-) diff --git a/ggml/src/ggml-sycl/fattn-tile.hpp b/ggml/src/ggml-sycl/fattn-tile.hpp index 29fd0f8c9ec..c4d24613a55 100644 --- a/ggml/src/ggml-sycl/fattn-tile.hpp +++ b/ggml/src/ggml-sycl/fattn-tile.hpp @@ -70,6 +70,7 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 64, 64) return 0; } @@ -310,11 +311,11 @@ static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const sycl::half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; auto load = [&] (const int n) { - auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); const int stride_j = warp_size >> n; if (stride_j == 0) { @@ -455,7 +456,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp, flash_attn_tile_load_tile (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #ifdef SYCL_FAST_FP16 static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); @@ -505,7 +506,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp, } if (k_KQ_0 + nbatch_K < DKQ) { - item_ct1.barrier(); // Sync not needed on last iteration. + item_ct1.barrier(sycl::access::fence_space::local_space); // Sync not needed on last iteration. } } @@ -545,7 +546,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, const int k_VKQ_max, const int col_Q_0, float * KQ_max_new_shared) { - auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -620,14 +621,14 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, } if constexpr (np == 1) { - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); } else { static_assert(cpw == 1, "bad cpw"); if (item_ct1.get_local_id(2) == 0) { KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0]; } - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np]; KQ_max_new[0] = warp_reduce_max(KQ_max_new[0]); } @@ -697,7 +698,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { flash_attn_tile_load_tile (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #ifdef SYCL_FAST_FP16 #pragma unroll @@ -765,7 +766,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, } } #endif // SYCL_FAST_FP16 - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); } } @@ -972,7 +973,7 @@ static void flash_attn_tile(const char * Q, } } - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); // Main loop over KV cache: const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11; @@ -1051,7 +1052,7 @@ static void flash_attn_tile(const char * Q, return; } - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #pragma unroll for (int ip = 1; ip < np; ++ip) { @@ -1193,37 +1194,39 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm constexpr size_t nbytes_shared = 0; - if constexpr (DV <= 256) { - if (Q->ne[1] > 16/ncols2) { - constexpr int cols_per_block = 32; - const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn, warp_size> - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); - return; + if (DV < 512 && Q->ne[1] < 32) { + if constexpr (ncols2 <= 32) { + if (Q->ne[1] > 16/ncols2) { + constexpr int cols_per_block = 32; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } } - } - - if (Q->ne[1] > 8/ncols2) { - constexpr int cols_per_block = 16; - const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn, warp_size> - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); - return; - } - - if constexpr (ncols2 <= 8) { - if (Q->ne[1] > 4/ncols2) { - constexpr int cols_per_block = 8; - const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn, warp_size> - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); - return; + if constexpr (ncols2 <= 16) { + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + } + if constexpr (ncols2 <= 8) { + if (Q->ne[1] > 4/ncols2) { + constexpr int cols_per_block = 8; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } } } From 5ffe58838dbb34db41f1f1c1db6c739c1b7fe83b Mon Sep 17 00:00:00 2001 From: hipudding Date: Tue, 31 Mar 2026 22:00:51 +0800 Subject: [PATCH 068/249] CANN: fix multi-thread set_tensor race conditions (llama/20151) * CANN: fix multi-thread set_tensor race conditions When ollama calls ggml_backend_tensor_set from multiple threads (each writing a different chunk of the same tensor), the CANN backend had three concurrency issues: 1. Quantized tensors (Q4_0/Q8_0) require a full-tensor format transform before uploading to device. Per-chunk transforms produced corrupt data. 2. ND-to-NZ weight conversion requires complete tensor data on device. Per-chunk conversion operated on incomplete data. 3. The global g_nz_workspaces array had unprotected concurrent access. Fix by introducing a TensorSetTracker that accumulates write progress per tensor. For quantized tensors, raw data is staged in a host buffer and the transform + upload is deferred until all chunks arrive. For NZ weights, chunks are uploaded directly but conversion is deferred. The tracker and its staging buffer are released immediately after post-processing completes. Add per-device mutex to g_nz_workspaces to prevent data races. * CANN: fix L2_NORM ignoring eps parameter The L2_NORM implementation was not using the eps parameter from op_params, causing incorrect results when eps is large (e.g. 10.0). The CPU reference computes scale = 1/fmaxf(norm, eps), so add a Clamp step to clamp the norm to at least eps before dividing. * ggml/cann: compare op_params for POOL_2D in ACL graph cache matching When ACL graph mode is enabled, the graph LRU cache checks whether a cached graph matches the current computation graph. Previously, GGML_OP_POOL_2D was not included in the op_params comparison, so two POOL_2D nodes with different pooling parameters (kernel size, stride, padding) but identical tensor shapes and addresses could incorrectly reuse a cached graph, leading to wrong results or aclnn errors. Add GGML_OP_POOL_2D to the list of ops that require op_params matching in ggml_graph_node_properties::has_matching_properties(). * cann: fix ACL graph cache matching by adding tensor type and unconditional op_params comparison The ACL graph LRU cache was incorrectly reusing cached graphs for operations with different tensor types or op_params, causing test failures for CPY (f16 vs bf16), POOL_2D, L2_NORM, NORM_MUL_ADD, RMS_NORM_MUL_ADD, and ADD_RMS_NORM. Changes: - Add node_type and src_type[] fields to ggml_graph_node_properties so the cache can distinguish tensors with different types but identical ne/nb (e.g. f16 and bf16 both have 2-byte elements) - Compare op_params unconditionally for all ops instead of only for SCALE/UNARY/GLU/ROPE/POOL_2D --- ggml/src/ggml-cann/aclnn_ops.cpp | 10 +++ ggml/src/ggml-cann/common.h | 30 ++++--- ggml/src/ggml-cann/ggml-cann.cpp | 129 +++++++++++++++++++++++++++---- 3 files changed, 145 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index adb4d68e868..a950475fc3b 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -434,6 +434,9 @@ void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src = dst->src[0]; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src); acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); @@ -456,6 +459,13 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { float p_value = 2.0f; acl_scalar_ptr p_scalar = ggml_cann_create_scalar(&p_value, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_div.get()); + + // Clamp norm to at least eps: scale = 1/fmaxf(norm, eps) + acl_scalar_ptr acl_min = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT); + float flt_max = FLT_MAX; + acl_scalar_ptr acl_max = ggml_cann_create_scalar(&flt_max, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_div.get(), acl_min.get(), acl_max.get(), acl_div.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div.get(), acl_dst.get()); } diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 5f960548cd2..1c6e685c38c 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -216,14 +216,16 @@ struct ggml_cann_pool_alloc { #ifdef USE_ACL_GRAPH struct ggml_graph_node_properties { // dst tensor - void * node_address; - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; + void * node_address; + ggml_type node_type; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; // src tensor - void * src_address[GGML_MAX_SRC]; - int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; - size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; + void * src_address[GGML_MAX_SRC]; + ggml_type src_type[GGML_MAX_SRC]; + int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; + size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; // op ggml_op node_op; @@ -247,6 +249,10 @@ struct ggml_graph_node_properties { return false; } + if (node->type != this->node_type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { if (node->ne[i] != this->ne[i]) { return false; @@ -262,6 +268,10 @@ struct ggml_graph_node_properties { return false; } + if (node->src[i]->type != this->src_type[i]) { + return false; + } + for (int d = 0; d < GGML_MAX_DIMS; d++) { if (node->src[i]->ne[d] != this->src_ne[i][d]) { return false; @@ -277,10 +287,7 @@ struct ggml_graph_node_properties { } } - if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU || node->op == GGML_OP_ROPE){ - return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; - } - return true; + return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; } }; @@ -322,6 +329,7 @@ struct ggml_cann_graph { prop.node_address = node->data; prop.node_op = node->op; + prop.node_type = node->type; std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne); std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb); @@ -329,10 +337,12 @@ struct ggml_cann_graph { for (int src = 0; src < GGML_MAX_SRC; ++src) { if (node->src[src]) { prop.src_address[src] = node->src[src]->data; + prop.src_type[src] = node->src[src]->type; std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]); std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]); } else { prop.src_address[src] = nullptr; + prop.src_type[src] = GGML_TYPE_COUNT; std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0); std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0); } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 6f26e91e046..40fe3d82ecc 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -36,10 +36,13 @@ #include #include #include +#include #include #include #include +#include #include +#include #define GGML_COMMON_DECL_C @@ -770,6 +773,21 @@ std::unique_ptr ggml_backend_cann_context::new_pool_for_device(i } // cann buffer + +/** + * @brief Tracks multi-threaded write progress for a single tensor. + * + * When multiple threads call set_tensor on different chunks of the same tensor, + * this tracker accumulates progress and defers post-processing (quantized format + * transform or ND-to-NZ conversion) until all data has been written. + */ +struct TensorSetTracker { + std::mutex mtx; ///< Protects concurrent access to this tracker + size_t bytes_written = 0; ///< Accumulated bytes written so far + size_t total_bytes = 0; ///< Target size (full tensor) + std::vector host_buffer; ///< Host staging buffer for quantized tensors +}; + /** * @brief Context for managing a CANN buffer associated with a specific device. * @@ -780,6 +798,9 @@ struct ggml_backend_cann_buffer_context { int32_t device; ///< The device ID associated with this buffer context. void * dev_ptr = nullptr; ///< Pointer to the device memory allocated for the buffer. + std::mutex tracker_mutex; ///< Protects the trackers map + std::unordered_map> trackers; + /** * @brief Constructor to initialize the CANN buffer context. * @@ -792,6 +813,31 @@ struct ggml_backend_cann_buffer_context { * @brief Destructor to free the device memory allocated for the buffer. */ ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); } + + /** + * @brief Get or create a tracker for the given tensor. + */ + TensorSetTracker * get_or_create_tracker(ggml_tensor * tensor) { + std::lock_guard lock(tracker_mutex); + auto key = tensor->data; + auto it = trackers.find(key); + if (it == trackers.end()) { + auto tracker = std::make_unique(); + tracker->total_bytes = ggml_nbytes(tensor); + auto * ptr = tracker.get(); + trackers[key] = std::move(tracker); + return ptr; + } + return it->second.get(); + } + + /** + * @brief Remove the tracker for the given tensor. + */ + void remove_tracker(ggml_tensor * tensor) { + std::lock_guard lock(tracker_mutex); + trackers.erase(tensor->data); + } }; // cann buffer type @@ -1124,6 +1170,7 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer * designed to be used with a global array, one per device. */ struct ggml_cann_nz_workspace { + std::mutex mtx; // Protects ptr/allocated from concurrent access void * ptr; // Pointer to allocated device buffer size_t allocated; // Size of currently allocated buffer in bytes @@ -1190,13 +1237,15 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES]; * @note The workspace buffer used in this function is managed globally and reused * across calls. This reduces overhead from repeated memory allocation and deallocation. */ -static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) { - acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset); +static void weight_format_to_nz(ggml_tensor * tensor, int device) { + acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, 0); uint64_t workspaceSize = 0; aclOpExecutor * executor; // TransMatmulWeight ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor)); + + std::lock_guard lock(g_nz_workspaces[device].mtx); // Avoid frequent malloc/free of the workspace. g_nz_workspaces[device].realloc(workspaceSize); @@ -1210,7 +1259,13 @@ static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) * @brief Set tensor data in a CANN buffer. * * This function sets tensor data in a CANN buffer, handling transformations - * if needed based on the tensor's type. + * if needed based on the tensor's type. It supports multi-threaded calls + * where different threads write different chunks of the same tensor. + * + * For quantized tensors (Q4_0/Q8_0), data is staged in a host buffer and + * the format transform is deferred until all chunks are written. + * For NZ weight tensors, chunks are uploaded directly but the ND-to-NZ + * conversion is deferred until all chunks are written. * * @param buffer The CANN buffer where the tensor data will be set. * @param tensor Pointer to the tensor whose data will be set. @@ -1226,26 +1281,72 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; ggml_cann_set_device(ctx->device); - // TODO: refer to cann(#6017), it use thread's default stream. - // For acl, synchronous functions use this default stream. - // Why aclrtSynchronizeDevice? // Only check env once. static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); - if (!need_transform(tensor->type)) { + + bool is_quantized = need_transform(tensor->type); + bool is_nz = !is_quantized && tensor->type != GGML_TYPE_BF16 && weight_to_nz && + is_matmul_weight((const ggml_tensor *) tensor); + + // Plain tensor (not quantized, not NZ): direct copy, no tracking needed + if (!is_quantized && !is_nz) { ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); - if (weight_to_nz && tensor->type != GGML_TYPE_BF16 - && is_matmul_weight((const ggml_tensor *) tensor)) { + return; + } + + // Single-shot write (full tensor at once): handle directly without tracking overhead + if (offset == 0 && size == ggml_nbytes(tensor)) { + if (is_quantized) { + void * transform_buffer = malloc(size); + ggml_backend_cann_transform(tensor, data, transform_buffer); + ACL_CHECK(aclrtMemcpy(tensor->data, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE)); + free(transform_buffer); + } else { + // NZ weight GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[3] == 1); - weight_format_to_nz(tensor, offset, ctx->device); + ACL_CHECK(aclrtMemcpy(tensor->data, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); + weight_format_to_nz(tensor, ctx->device); } + return; + } + + // Chunked write: use tracker to accumulate progress and defer transform/conversion + TensorSetTracker * tracker = ctx->get_or_create_tracker(tensor); + std::unique_lock lock(tracker->mtx); + + if (is_quantized) { + // Stage data in host buffer; transform requires full tensor data + if (tracker->host_buffer.empty()) { + tracker->host_buffer.resize(tracker->total_bytes); + } + memcpy(tracker->host_buffer.data() + offset, data, size); } else { - void * transform_buffer = malloc(size); - ggml_backend_cann_transform(tensor, data, transform_buffer); + // NZ weight: upload chunk to device immediately, defer conversion + ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); + } - ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE)); - free(transform_buffer); + tracker->bytes_written += size; + + // All chunks received: perform deferred transform/conversion + if (tracker->bytes_written >= tracker->total_bytes) { + if (is_quantized) { + void * transform_buffer = malloc(tracker->total_bytes); + ggml_backend_cann_transform(tensor, tracker->host_buffer.data(), transform_buffer); + ACL_CHECK(aclrtMemcpy(tensor->data, tracker->total_bytes, transform_buffer, tracker->total_bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + free(transform_buffer); + } + + if (is_nz) { + GGML_ASSERT(tensor->ne[2] == 1); + GGML_ASSERT(tensor->ne[3] == 1); + weight_format_to_nz(tensor, ctx->device); + } + + // Unlock before removing tracker, as remove_tracker destroys the mutex + lock.unlock(); + ctx->remove_tracker(tensor); } } From 21b9dd6789eac3db4e152aca87c727874e2f0cf1 Mon Sep 17 00:00:00 2001 From: Abhijit Ramesh Date: Wed, 1 Apr 2026 12:58:53 +0300 Subject: [PATCH 069/249] ggml-webgpu: port all AOT operators to JIT (llama/20728) * port cpy pipeline to shader lib with JIT compilation * port glu pipeline to shader lib with JIT compilation * port rope pipeline to shader lib with JIT compilation * port soft_max pipeline to shader lib with JIT compilation * removed unused functions from embed_wgsl.py which were used for old AOT template expansion --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 325 ++++++++++++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 224 ++++-------- ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl | 81 +++++ .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 107 +----- ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl | 155 +++++++++ ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl | 224 ++++++++++++ .../ggml-webgpu/wgsl-shaders/soft_max.wgsl | 245 +++++++++++++ 7 files changed, 1097 insertions(+), 264 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 59861ac16cc..97863f40412 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -535,6 +535,95 @@ struct ggml_webgpu_mul_mat_shader_decisions { uint32_t mul_mat_wg_size; }; +/** Cpy **/ + +struct ggml_webgpu_cpy_pipeline_key { + ggml_type src_type; + ggml_type dst_type; + + bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const { + return src_type == other.src_type && dst_type == other.dst_type; + } +}; + +struct ggml_webgpu_cpy_pipeline_key_hash { + size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + return seed; + } +}; + +/** Glu **/ + +struct ggml_webgpu_glu_pipeline_key { + ggml_glu_op glu_op; + ggml_type type; + bool split; + + bool operator==(const ggml_webgpu_glu_pipeline_key & other) const { + return glu_op == other.glu_op && type == other.type && split == other.split; + } +}; + +struct ggml_webgpu_glu_pipeline_key_hash { + size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.glu_op); + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.split); + return seed; + } +}; + +/** Rope **/ + +struct ggml_webgpu_rope_pipeline_key { + ggml_type type; + bool inplace; + bool has_ff; + + bool operator==(const ggml_webgpu_rope_pipeline_key & other) const { + return type == other.type && inplace == other.inplace && has_ff == other.has_ff; + } +}; + +struct ggml_webgpu_rope_pipeline_key_hash { + size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.has_ff); + return seed; + } +}; + +/** SoftMax **/ + +struct ggml_webgpu_soft_max_pipeline_key { + ggml_type mask_type; + bool has_mask; + bool has_sink; + bool inplace; + + bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const { + return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink && + inplace == other.inplace; + } +}; + +struct ggml_webgpu_soft_max_pipeline_key_hash { + size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.mask_type); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sink); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + class ggml_webgpu_shader_lib { wgpu::Device device; pre_wgsl::Preprocessor preprocessor; @@ -582,6 +671,12 @@ class ggml_webgpu_shader_lib { std::unordered_map set_rows_pipelines; std::unordered_map set_pipelines; + std::unordered_map cpy_pipelines; + std::unordered_map glu_pipelines; + std::unordered_map + rope_pipelines; + std::unordered_map + soft_max_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -1679,6 +1774,236 @@ class ggml_webgpu_shader_lib { return flash_attn_pipelines[key]; } + webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_cpy_pipeline_key key = { + .src_type = context.src0->type, + .dst_type = context.dst->type, + }; + + auto it = cpy_pipelines.find(key); + if (it != cpy_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "cpy"; + + switch (key.src_type) { + case GGML_TYPE_F32: + defines.push_back("SRC_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src type for cpy shader"); + } + + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + variant += "_f16"; + break; + case GGML_TYPE_I32: + defines.push_back("DST_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported dst type for cpy shader"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_cpy, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + cpy_pipelines[key] = pipeline; + return cpy_pipelines[key]; + } + + webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_glu_pipeline_key key = { + .glu_op = ggml_get_glu_op(context.dst), + .type = context.dst->type, + .split = (context.src1 != nullptr), + }; + + auto it = glu_pipelines.find(key); + if (it != glu_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "glu"; + + switch (key.glu_op) { + case GGML_GLU_OP_REGLU: + defines.push_back("OP_REGLU"); + variant += "_reglu"; + break; + case GGML_GLU_OP_GEGLU: + defines.push_back("OP_GEGLU"); + variant += "_geglu"; + break; + case GGML_GLU_OP_SWIGLU: + defines.push_back("OP_SWIGLU"); + variant += "_swiglu"; + break; + case GGML_GLU_OP_SWIGLU_OAI: + defines.push_back("OP_SWIGLU_OAI"); + variant += "_swiglu_oai"; + break; + case GGML_GLU_OP_GEGLU_ERF: + defines.push_back("OP_GEGLU_ERF"); + variant += "_geglu_erf"; + break; + case GGML_GLU_OP_GEGLU_QUICK: + defines.push_back("OP_GEGLU_QUICK"); + variant += "_geglu_quick"; + break; + default: + GGML_ABORT("Unsupported GLU op"); + } + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for GLU shader"); + } + + if (key.split) { + variant += "_split"; + } else { + defines.push_back("NO_SPLIT"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_glu, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + glu_pipelines[key] = pipeline; + return glu_pipelines[key]; + } + + webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_rope_pipeline_key key = { + .type = context.dst->type, + .inplace = context.inplace, + .has_ff = (context.src2 != nullptr), + }; + + auto it = rope_pipelines.find(key); + if (it != rope_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "rope"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for ROPE shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + if (key.has_ff) { + defines.push_back("FF_FUNC"); + variant += "_ff"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_rope, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + rope_pipelines[key] = pipeline; + return rope_pipelines[key]; + } + + webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_soft_max_pipeline_key key = { + .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32, + .has_mask = (context.src1 != nullptr), + .has_sink = (context.src2 != nullptr), + .inplace = context.inplace, + }; + + auto it = soft_max_pipelines.find(key); + if (it != soft_max_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "soft_max"; + + if (key.has_mask) { + defines.push_back("HAS_MASK"); + switch (key.mask_type) { + case GGML_TYPE_F32: + defines.push_back("MASK_F32"); + variant += "_mask_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("MASK_F16"); + variant += "_mask_f16"; + break; + default: + GGML_ABORT("Unsupported type for SOFT_MAX shader"); + } + } + + if (key.has_sink) { + defines.push_back("HAS_SINK"); + variant += "_sink"; + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_soft_max, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + soft_max_pipelines[key] = pipeline; + return soft_max_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5e16f84ddd2..fa3c492a7a5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -364,13 +364,6 @@ struct webgpu_context_struct { wgpu::Buffer set_rows_dev_error_buf; wgpu::Buffer set_rows_host_error_buf; - std::map> cpy_pipelines; // src_type, dst_type - - std::map>> rope_pipelines; // type, ff, inplace - std::map>> glu_pipelines; // glu_op, type, split - - std::map>> soft_max_pipelines; // mask_type, has_sink, inplace - size_t memset_bytes_per_thread; }; @@ -849,6 +842,16 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 } static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { @@ -875,9 +878,8 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type], - params, entries, wg_x); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { @@ -1914,6 +1916,19 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = ggml_webgpu_tensor_equal(src0, dst), + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + const int inplace = ggml_webgpu_tensor_equal(src0, dst); const int has_freq_factor = (src2 != nullptr); @@ -1996,12 +2011,22 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace]; - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + const int split = (src1 != nullptr); std::vector params = { @@ -2048,8 +2073,7 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split]; - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -2109,9 +2133,20 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here - const int has_sink = (src2 != nullptr); + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = ggml_webgpu_tensor_equal(src0, dst), + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); + + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); @@ -2120,15 +2155,15 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), - mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, - mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, - mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), @@ -2136,8 +2171,8 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], (uint32_t) src0->ne[2], - mask_type < 2 ? (uint32_t) src1->ne[2] : 0, - mask_type < 2 ? (uint32_t) src1->ne[3] : 0, + has_mask ? (uint32_t) src1->ne[2] : 0, + has_mask ? (uint32_t) src1->ne[3] : 0, *(uint32_t *) dst->op_params, // scale *(uint32_t *) &max_bias, *(uint32_t *) &n_head_log2, @@ -2152,7 +2187,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, src0) } }; uint32_t binding_num = 1; - if (mask_type < 2) { + if (has_mask) { entries.push_back({ .binding = binding_num, .buffer = ggml_webgpu_tensor_buf(src1), .offset = ggml_webgpu_tensor_align_offset(ctx, src1), @@ -2173,9 +2208,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, - ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries, - ggml_nrows(dst)); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst)); } static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -2885,139 +2918,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); -} - -static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); - - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); -} - -static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - // REGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); - - // GEGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); - - // SWIGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); - - // SWIGLU_OAI - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); - - // GEGLU_ERF - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); - - // GEGLU_QUICK - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); -} - -static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - - // f32 (no mask) - webgpu_ctx->soft_max_pipelines[2][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); - webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); - - // f32 mask (mask_type = 0) - webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); - webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[0][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, - "soft_max_f32_mask_f32_sink_inplace", constants); - - // f16 mask (mask_type = 1) - webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); - webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); - webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); - webgpu_ctx->soft_max_pipelines[1][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, - "soft_max_f32_mask_f16_sink_inplace", constants); -} - static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { wgpu::RequestAdapterOptions options = {}; @@ -3183,10 +3083,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); - ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - ggml_webgpu_init_rope_pipeline(webgpu_ctx); - ggml_webgpu_init_glu_pipeline(webgpu_ctx); - ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl new file mode 100644 index 00000000000..fa3bdf4e393 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -0,0 +1,81 @@ +enable f16; + +#ifdef SRC_F32 +#define SRC_TYPE f32 +#elif defined(SRC_F16) +#define SRC_TYPE f16 +#endif + +#ifdef DST_F32 +#define DST_TYPE f32 +#elif defined(DST_F16) +#define DST_TYPE f16 +#elif defined(DST_I32) +#define DST_TYPE i32 +#endif + +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var dst: array; + +struct Params{ + ne: u32, + offset_src: u32, + offset_dst: u32, + + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + +@group(0) @binding(2) +var params: Params; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; + + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx])); +} + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 8b5cfe715e7..79a3a9597ab 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -1,41 +1,8 @@ import os import re -import ast import argparse -def extract_block(text, name): - pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)' - match = re.search(pattern, text, re.DOTALL) - if not match: - raise ValueError(f"Missing block: {name}") - return match.group(1).strip() - - -def parse_decls(decls_text): - decls = {} - for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL): - decls[name.strip()] = code.strip() - return decls - - -def replace_repl_placeholders(variant, template_map): - for repl, code in variant["REPLS"].items(): - for key, val in template_map.items(): - # Match "key" and avoid matching subsequences using by using \b - code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code) - variant["REPLS"][repl] = code - return variant - - -def replace_placeholders(shader_text, replacements): - for key, val in replacements.items(): - # Match {{KEY}} literally, where KEY is escaped - pattern = r'{{\s*' + re.escape(key) + r'\s*}}' - shader_text = re.sub(pattern, str(val), shader_text) - return shader_text - - def expand_includes(shader, input_dir): """ Replace #include "file" lines in the text with the contents of that file. @@ -98,84 +65,24 @@ def write_shader(shader_name, shader_code, output_dir, outfile, input_dir): outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n') -def generate_variants(fname, input_dir, output_dir, outfile): - shader_path = os.path.join(input_dir, fname) - shader_base_name = fname.split(".")[0] - - with open(shader_path, "r", encoding="utf-8") as f: - text = f.read() - - try: - variants = ast.literal_eval(extract_block(text, "VARIANTS")) - except ValueError: - write_shader(shader_base_name, text, output_dir, outfile, input_dir) - else: - try: - decls_map = parse_decls(extract_block(text, "DECLS")) - except ValueError: - decls_map = {} - try: - templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES")) - except ValueError: - templates_map = {} - - for fname in sorted(os.listdir(input_dir)): - if fname.endswith(".tmpl"): - tmpl_path = os.path.join(input_dir, fname) - with open(tmpl_path, "r", encoding="utf-8") as f_tmpl: - decls = f_tmpl.read() - decls_map.update(parse_decls(decls)) - - shader_template = extract_block(text, "SHADER") - for variant in variants: - if "DECLS" in variant: - decls = variant["DECLS"] - else: - decls = [] - decls_code = "" - for key in decls: - if key not in decls_map: - raise ValueError(f"DECLS key '{key}' not found.") - decls_code += decls_map[key] + "\n\n" - final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) - if "REPLS" in variant: - variant = replace_repl_placeholders(variant, templates_map) - final_shader = replace_placeholders(final_shader, variant["REPLS"]) - # second run to expand placeholders in repl_template - final_shader = replace_placeholders(final_shader, variant["REPLS"]) - final_shader = expand_includes(final_shader, input_dir) - - if "SHADER_NAME" in variant: - output_name = variant["SHADER_NAME"] - elif "SHADER_SUFFIX" in variant: - output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"] - elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]]) - elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]]) - elif "REPLS" in variant and "TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] - else: - output_name = shader_base_name - write_shader(output_name, final_shader, output_dir, outfile, input_dir) - - def main(): parser = argparse.ArgumentParser() parser.add_argument("--input_dir", required=True) parser.add_argument("--output_file", required=True) - parser.add_argument("--output_dir") args = parser.parse_args() - if args.output_dir: - os.makedirs(args.output_dir, exist_ok=True) - with open(args.output_file, "w", encoding="utf-8") as out: out.write("// Auto-generated shader embedding\n") out.write("#include \n\n") for fname in sorted(os.listdir(args.input_dir)): if fname.endswith(".wgsl"): - generate_variants(fname, args.input_dir, args.output_dir, out) + shader_path = os.path.join(args.input_dir, fname) + shader_name = fname.replace(".wgsl", "") + + with open(shader_path, "r", encoding="utf-8") as f: + shader_code = f.read() + + write_shader(shader_name, shader_code, None, out, args.input_dir) if __name__ == "__main__": diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl new file mode 100644 index 00000000000..e6d7608cec5 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl @@ -0,0 +1,155 @@ +enable f16; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + +#ifdef OP_REGLU +fn op(a: DataType, b: DataType) -> DataType { + return max(a, 0) * b; +} +#endif + +#ifdef OP_GEGLU +const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876; +const GELU_COEF_A: DataType = 0.044715; + +fn op(a: DataType, b: DataType) -> DataType { + let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); + return 0.5 * a * (2.0 - 2.0/ (exp(2* val) + 1)) * b; +} +#endif + +#ifdef OP_SWIGLU +fn op(a: DataType, b: DataType) -> DataType { + return a / (1.0 + exp(-a)) * b; +} +#endif +#ifdef OP_SWIGLU_OAI +fn op(a: f32, b: f32) -> f32 { + let xi = min(a, params.limit); + let gi = max(min(b, params.limit), -params.limit); + var out_glu = xi / (1.0 + exp(-xi * params.alpha)); + out_glu = out_glu * (1.0 + gi); + return out_glu; +} +#endif +#ifdef OP_GEGLU_ERF +const p_erf: DataType = 0.3275911; +const a1_erf: DataType = 0.254829592; +const a2_erf: DataType = -0.284496736; +const a3_erf: DataType = 1.421413741; +const a4_erf: DataType = -1.453152027; +const a5_erf: DataType = 1.061405429; +const SQRT_2_INV: DataType = 0.7071067811865476; + +fn op(a: DataType, b: DataType) -> DataType { + let a_div_sqr2 = a * SQRT_2_INV; + let sign_x = sign(a_div_sqr2); + let x = abs(a_div_sqr2); + let t = 1.0 / (1.0 + p_erf * x); + let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); + let erf_approx = sign_x * y; + return 0.5 * a * (1.0 + erf_approx) * b; +} +#endif +#ifdef OP_GEGLU_QUICK +const GELU_QUICK_COEF: DataType = -1.702; + +fn op(a: DataType, b: DataType) -> DataType { + return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; +} +#endif + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + swapped: u32, + alpha: f32, + limit: f32, +} + +@group(0) @binding(0) +var src0: array; + +#ifdef NO_SPLIT +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +fn a_value(base: u32) -> DataType { + let offset: u32 = select(0, params.ne0, params.swapped != 0); + return src0[base + offset]; +} + +fn b_value(base: u32) -> DataType { + let offset: u32 = select(params.ne0, 0, params.swapped != 0); + return src0[base + offset]; +} + +#else +@group(0) @binding(1) +var src1: array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; + +fn a_value(base: u32) -> DataType { + return src0[base]; +} + +fn b_value(base: u32) -> DataType { + return src1[base]; +} + +#endif + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; + let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; + let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; + + dst[i_dst] = op(a_value(i_a), b_value(i_b)); +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl new file mode 100644 index 00000000000..1c874e14240 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl @@ -0,0 +1,224 @@ +enable f16; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_src2: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + n_threads: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + n_dims: u32, + mode: u32, + theta_scale: f32, + attn_factor: f32, + freq_scale: f32, + ext_factor: f32, + corr_dim0: f32, + corr_dim1: f32, + sections0: u32, + sections1: u32, + sections2: u32, + sections3: u32 +}; + +@group(0) @binding(0) +var src0: array; +@group(0) @binding(1) +var src1: array; + +#ifdef INPLACE + +#ifdef FF_FUNC + +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var params: Params; + +#else + +@group(0) @binding(2) +var params: Params; + +#endif + +#else + +#ifdef FF_FUNC +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var dst: array; + +@group(0) @binding(4) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; +#endif +#endif + +#ifdef FF_FUNC +fn freq_factor(i: u32) -> f32 { + return src2[params.offset_src2 + i/2]; +} + +#else +fn freq_factor(i: u32) -> f32 { + return 1.0f; +} +#endif +#ifdef INPLACE +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + src0[i_dst0] = DataType(out0); + src0[i_dst1] = DataType(out1); +} +#else +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + dst[i_dst0] = DataType(out0); + dst[i_dst1] = DataType(out1); +} +#endif + +fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { + let y = (f32(i / 2) - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// returns vector of (cos_theta, sin_theta) +// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row +fn rope_yarn(theta_extrap: f32, i: u32) -> vec2 { + var mscale = params.attn_factor; + var theta = params.freq_scale * theta_extrap; + if (params.ext_factor != 0.0f) { + let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor; + theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix; + mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale); + } + return vec2(cos(theta) * mscale, sin(theta) * mscale); +} + +fn pair_base(i0: u32, div_2: bool) -> u32 { + if (div_2) { + return i0 / 2; + } else { + return i0; + } +} + +fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 { + if (is_vision) { + return params.n_dims; + } else if (is_neox || is_mrope) { + return params.n_dims / 2; + } else { + return 1; + } +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + // two elements per n_threads + if (gid.x >= params.n_threads) { + return; + } + + let is_neox = bool(params.mode & 2); + let is_mrope = bool(params.mode & 8); + let is_imrope = params.mode == 40; + let is_vision = params.mode == 24; + + var i = gid.x * 2; // start index for this thread + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + if (i0 >= params.n_dims && !is_vision) { + let i_src = i_src_row + i0; + let i_dst = i_dst_row + i0; + rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1])); + return; + } + + var theta_base_mult: u32 = 0; + var theta_scale_pwr: u32 = i0 / 2; + if (is_mrope) { + let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3; + let sec_w = params.sections1 + params.sections0; + let sec_e = params.sections2 + sec_w; + let sector = (i0 / 2) % sect_dims; + if (is_imrope) { + if (sector % 3 == 1 && sector < 3 * params.sections1) { + theta_base_mult = 1; + } else if (sector % 3 == 2 && sector < 3 * params.sections2) { + theta_base_mult = 2; + } else if (sector % 3 == 0 && sector < 3 * params.sections0) { + theta_base_mult = 0; + } else { + theta_base_mult = 3; + } + } else { + if (sector >= params.sections0 && sector < sec_w) { + theta_base_mult = 1; + if (is_vision) { + theta_scale_pwr = sector - params.sections0; + } + } else if (sector >= sec_w && sector < sec_e) { + theta_base_mult = 2; + if (is_vision) { + theta_scale_pwr = sector - sec_w; + } + } else if (sector >= sec_e) { + if (is_vision) { + theta_scale_pwr = sector - sec_e; + theta_scale_pwr = (i0 / 2) % sec_e; + } + theta_base_mult = 3; + } else if (is_vision) { + theta_scale_pwr = sector; + } + } + } + let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr)); + let thetas = rope_yarn(theta_base/freq_factor(i0), i0); + + let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision); + let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision); + + let x0 = f32(src0[i_src]); + let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]); + rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x); + +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl new file mode 100644 index 00000000000..10edf136048 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl @@ -0,0 +1,245 @@ +enable f16; + +#ifdef MASK_F32 +#define MaskType f32 +#endif +#ifdef MASK_F16 +#define MaskType f16 +#endif + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_sinks: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of src0/dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + // shape of src1 + ne12: u32, + ne13: u32, + + scale: f32, + max_bias: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) +var src: array; + +#ifdef HAS_MASK +#ifdef HAS_SINK +@group(0) @binding(1) +var mask: array; +@group(0) @binding(2) +var sinks: array; + +#ifdef INPLACE +@group(0) @binding(3) +var params: Params; + +#else +@group(0) @binding(3) +var dst: array; +@group(0) @binding(4) +var params: Params; +#endif + +#else +@group(0) @binding(1) +var mask: array; + +#ifdef INPLACE +@group(0) @binding(2) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; +@group(0) @binding(3) +var params: Params; +#endif +#endif + +#else +#ifdef HAS_SINK +@group(0) @binding(1) +var sinks: array; + +#ifdef INPLACE +@group(0) @binding(2) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; +@group(0) @binding(3) +var params: Params; +#endif + +#else +#ifdef INPLACE +@group(0) @binding(1) +var params: Params; +#else +@group(0) @binding(1) +var dst: array; +@group(0) @binding(2) +var params: Params; +#endif +#endif +#endif + +#ifdef INPLACE +fn inter_value(i: u32) -> f32 { + return src[i]; +} +fn update(i: u32, val: f32) { + src[i] = val; +} + +#else +fn inter_value(i: u32) -> f32 { + return dst[i]; +} +fn update(i: u32, val: f32) { + dst[i] = val; +} +#endif + +#ifdef HAS_MASK +fn mask_val(i: u32) -> f32 { + return f32(mask[i]); +} + +#else +fn mask_val(i: u32) -> f32 { + return 0.0; +} +#endif + +#ifdef HAS_SINK +fn lower_max_bound(i2: u32) -> f32 { + return sinks[params.offset_sinks + i2]; +} +fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { + return val + exp(sinks[params.offset_sinks + i2] - max_val); +} +#else +fn lower_max_bound(i2: u32) -> f32 { + return -1e30; +} +fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { + return val; +} +#endif + +const CACHE_SIZE: u32 = 16; +var scratch: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; + let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; + + let head = f32(i2); + let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0); + + var cache: array; + + var max_val = lower_max_bound(i2); + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col); + max_val = max(max_val, val); + if (col < CACHE_SIZE) { + cache[col] = val; + } + col += WG_SIZE; + } + + scratch[lid.x] = max_val; + workgroupBarrier(); + var offset: u32 = WG_SIZE / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]); + } + offset = offset / 2; + workgroupBarrier(); + } + let row_max = scratch[0]; + workgroupBarrier(); + + var sum = 0.0f; + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col), + cache[col], col < CACHE_SIZE); + let ex = exp(val - row_max); + sum += ex; + if (col < CACHE_SIZE) { + cache[col] = ex; + } else { + update(i_dst_row + col, ex); + } + col += WG_SIZE; + } + + scratch[lid.x] = sum; + workgroupBarrier(); + offset = WG_SIZE / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + let row_sum = add_sinks(scratch[0], i2, row_max); + + let sum_recip = 1.0 / row_sum; + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); + col += WG_SIZE; + } +} + From 78f54d15d80aded8a603e12fb066539acdb32f49 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 31 Mar 2026 22:38:24 -0700 Subject: [PATCH 070/249] ggml webgpu: quantized buffers to u32 + wider browser/device support (llama/21046) * Work towards removing bitcast * Move rest of existing types over * Add timeout back to wait and remove synchronous set_tensor/memset_tensor * move to unpackf16 for wider compatibility * cleanup * Remove deadlock condition in free_bufs --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 10 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 17 +- .../wgsl-shaders/common_decls.tmpl | 24 +++ .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 81 ++++++-- .../wgsl-shaders/mul_mat_decls.tmpl | 196 +++++++----------- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 103 ++++----- 6 files changed, 207 insertions(+), 224 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 97863f40412..a194ce84e25 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1219,9 +1219,8 @@ class ggml_webgpu_shader_lib { defines.push_back("BYTE_HELPERS"); defines.push_back("MUL_ACC_" + type_upper); - - // For fast path we always dequantize from f16 inside the shader - defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); break; } } @@ -1334,9 +1333,8 @@ class ggml_webgpu_shader_lib { defines.push_back("MUL_ACC_" + type_upper); defines.push_back("INIT_SRC0_SHMEM_" + type_upper); defines.push_back("INIT_SRC1_SHMEM_FLOAT"); - - // Use f16 inside the shader for quantized types - defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); variant += std::string("_") + src0_name; break; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index fa3c492a7a5..1aa15b0507c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -83,7 +83,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #define WEBGPU_NUM_PARAM_BUFS 96u #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u -#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 +#define WEBGPU_WAIT_ANY_TIMEOUT_MS 100 // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) @@ -171,6 +171,7 @@ struct webgpu_buf_pool { // Try growing the pool if no free buffers if (free.empty() && cur_pool_size < max_pool_size && should_grow) { cur_pool_size++; + lock.unlock(); // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks wgpu::Buffer dev_buf; ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); @@ -507,7 +508,7 @@ static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; while (blocking_wait) { - auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, 0); + auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6); if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { #ifdef GGML_WEBGPU_GPU_PROFILE ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); @@ -728,7 +729,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); std::vector commands = { command }; std::vector sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) }; - ggml_backend_webgpu_wait(ctx, sub); } /** End WebGPU Actions */ @@ -2694,17 +2694,6 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, // memset the remaining bytes ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size); - } else { - // wait for WriteBuffer to complete - buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", - std::string(message).c_str()); - } - }), - UINT64_MAX); } WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 9a5b18ebc07..feb0bca3f84 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -8,6 +8,30 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { } #endif +#ifdef U32_DEQUANT_HELPERS +fn load_src0_u16_at(byte_offset: u32) -> u32 { + let word = src0[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_src0_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = src0[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = src0[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} + +fn load_src0_f16_at(byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_src0_u16_at(byte_offset)); + return f16(packed[0]); +} +#endif + #ifdef Q4_0_T struct q4_0 { d: f16, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index b6822161464..8b76cecba91 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -6,6 +6,8 @@ enable chromium_experimental_subgroup_matrix; #ifdef KV_F32 #define KV_TYPE f32 +#elif defined(KV_Q4_0) || defined(KV_Q8_0) +#define KV_TYPE u32 #else #define KV_TYPE f16 #endif @@ -37,11 +39,13 @@ enable chromium_experimental_subgroup_matrix; #define NQ 16 // Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights #define F16_PER_BLOCK 9 +#define BLOCK_SIZE_BYTES 18u #define WEIGHTS_PER_F16 4 #elif defined(KV_Q8_0) #define NQ 8 // Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights #define F16_PER_BLOCK 17 +#define BLOCK_SIZE_BYTES 34u #define WEIGHTS_PER_F16 2 #endif #define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) @@ -55,6 +59,47 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; } +#if defined(KV_Q4_0) || defined(KV_Q8_0) +fn load_k_u16_at(byte_offset: u32) -> u32 { + let word = K[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_k_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = K[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = K[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} + +fn load_v_u16_at(byte_offset: u32) -> u32 { + let word = V[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_v_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = V[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = V[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} + +fn f16_from_u16(bits: u32) -> f16 { + let packed = unpack2x16float(bits); + return f16(packed[0]); +} +#endif + struct Params { offset_q: u32, offset_k: u32, @@ -254,12 +299,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_k_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -282,12 +326,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_k_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f16(q_byte) * d; @@ -459,12 +502,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_v_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -487,12 +529,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_v_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f16(q_byte) * d; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index de60ebbcf2b..eb228537bad 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -61,10 +61,10 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 18u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -81,14 +81,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_src0_f16_at(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1u + block_offset + j]; - let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -104,10 +102,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q4_1 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 20u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -124,15 +122,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; - let m = src0[scale_idx + 1u]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_src0_f16_at(block_byte_base); + let m = load_src0_f16_at(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_lo = f16(q_byte & 0xF) * d + m; @@ -149,11 +145,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q5_0 // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 22u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. // tile_k is defined as 32u, so blocks_k ends up being 1 always override BLOCKS_K = TILE_K / BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights @@ -171,18 +167,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let qh0 = src0[scale_idx + 1u]; - let qh1 = src0[scale_idx + 2u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = load_src0_f16_at(block_byte_base); + let qh_packed = load_src0_u32_at(block_byte_base + 2u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); + let q_packed = load_src0_u32_at(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -207,11 +199,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q5_1 // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 24u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. // tile_k is defined as 32u, so blocks_k ends up being 1 always override BLOCKS_K = TILE_K / BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights @@ -229,20 +221,16 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let m = src0[scale_idx + 1u]; - let qh0 = src0[scale_idx + 2u]; - let qh1 = src0[scale_idx + 3u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = load_src0_f16_at(block_byte_base); + let m = load_src0_f16_at(block_byte_base + 2u); + let qh_packed = load_src0_u32_at(block_byte_base + 4u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); + let q_packed = load_src0_u32_at(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -266,10 +254,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q8_0 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 34u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread @@ -286,14 +274,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_src0_f16_at(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j+=2) { - let q_0 = src0[scale_idx + 1u + block_offset + j]; - let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -308,10 +294,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q8_1 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 36u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block @@ -328,15 +314,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; - let m = src0[scale_idx + 1u]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_src0_f16_at(block_byte_base); + let m = load_src0_f16_at(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j+=2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -351,7 +335,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q2_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 42u; +const BLOCK_SIZE_BYTES = 84u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { // Use standard thread layout instead of lane/row_group @@ -371,10 +355,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx + 40u]; - let dmin = src0[scale_idx + 41u]; + let d = load_src0_f16_at(block_byte_base + 80u); + let dmin = load_src0_f16_at(block_byte_base + 82u); // Decode the element at position k_in_block let block_of_32 = k_in_block / 32u; @@ -387,18 +371,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let is = k_in_block / 16u; - let sc_0 = src0[scale_idx + 2u * (is / 4u)]; - let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u]; - let sc_packed = bitcast(vec2(sc_0, sc_1)); + let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u)); let sc = get_byte(sc_packed, is % 4u); let dl = d * f16(sc & 0xFu); let ml = dmin * f16(sc >> 4u); let q_idx = q_b_idx + k + l; - let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; - let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 3u; @@ -410,7 +390,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q3_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 55u; +const BLOCK_SIZE_BYTES = 110u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -429,9 +409,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx + 54u]; + let d = load_src0_f16_at(block_byte_base + 108u); // Load and unpack scales let kmask1: u32 = 0x03030303u; @@ -439,9 +419,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var scale_vals: array; for (var i: u32 = 0u; i < 4u; i++) { - let scale_0 = src0[scale_idx + 48u + (2u*i)]; - let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u]; - scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i); } var tmp: u32 = scale_vals[2]; @@ -453,16 +431,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load hmask and qs arrays var hmask_vals: array; for (var i: u32 = 0u; i < 8u; i++) { - let hmask_0 = src0[scale_idx + (2u*i)]; - let hmask_1 = src0[scale_idx + (2u*i) + 1u]; - hmask_vals[i] = bitcast(vec2(hmask_0, hmask_1)); + hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i); } var qs_vals: array; for (var i: u32 = 0u; i < 16u; i++) { - let qs_0 = src0[scale_idx + 16u + (2u*i)]; - let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u]; - qs_vals[i] = bitcast(vec2(qs_0, qs_1)); + qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i); } let half = k_in_block / 128u; // 0 or 1 @@ -502,7 +476,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q4_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 72u; +const BLOCK_SIZE_BYTES = 144u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -521,17 +495,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let dmin = src0[scale_idx + 1u]; + let d = load_src0_f16_at(block_byte_base); + let dmin = load_src0_f16_at(block_byte_base + 2u); // Load packed scales var scale_vals: array; for (var i: u32 = 0u; i < 3u; i++) { - let scale_0 = src0[scale_idx + 2u + (2u*i)]; - let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; - scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i); } // Map k_in_block to loop structure: @@ -567,9 +539,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; - let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 0xFu; @@ -582,7 +552,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q5_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 88u; +const BLOCK_SIZE_BYTES = 176u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -601,17 +571,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let dmin = src0[scale_idx + 1u]; + let d = load_src0_f16_at(block_byte_base); + let dmin = load_src0_f16_at(block_byte_base + 2u); // Load packed scales var scale_vals: array; for (var i: u32 = 0u; i < 3u; i++) { - let scale_0 = src0[scale_idx + 2u + (2u*i)]; - let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; - scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i); } // The original loop processes elements in groups of 64 @@ -651,15 +619,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)]; - let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); - let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)]; - let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u]; - let qh_packed = bitcast(vec2(qh_0, qh_1)); + let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u)); let qh_byte = get_byte(qh_packed, l % 4u); @@ -675,7 +639,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q6_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 105u; +const BLOCK_SIZE_BYTES = 210u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -694,7 +658,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let half = k_in_block / 128u; let pos_in_half = k_in_block % 128u; @@ -707,30 +671,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only ql13 word needed let ql13_flat = ql_b_idx + l; - let ql13_word = ql13_flat / 4u; - let ql13 = bitcast(vec2( - src0[scale_idx + 2u * ql13_word], - src0[scale_idx + 2u * ql13_word + 1u] - )); - let ql13_b = get_byte(ql13, ql13_flat % 4u); + let ql13 = load_src0_u32_at(block_byte_base + ql13_flat); + let ql13_b = get_byte(ql13, 0u); // Load only ql24 word needed let ql24_flat = ql_b_idx + l + 32u; - let ql24_word = ql24_flat / 4u; - let ql24 = bitcast(vec2( - src0[scale_idx + 2u * ql24_word], - src0[scale_idx + 2u * ql24_word + 1u] - )); - let ql24_b = get_byte(ql24, ql24_flat % 4u); + let ql24 = load_src0_u32_at(block_byte_base + ql24_flat); + let ql24_b = get_byte(ql24, 0u); // Load only qh word needed let qh_flat = qh_b_idx + l; - let qh_word = qh_flat / 4u; - let qh = bitcast(vec2( - src0[scale_idx + 64u + 2u * qh_word], - src0[scale_idx + 64u + 2u * qh_word + 1u] - )); - let qh_b = get_byte(qh, qh_flat % 4u); + let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat); + let qh_b = get_byte(qh, 0u); let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0); @@ -740,14 +692,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only the scale word needed let is = l / 16u; let sc_idx = sc_b_idx + is + quarter * 2u; - let sc_word = sc_idx / 4u; - let sc = bitcast(vec2( - src0[scale_idx + 96u + 2u * sc_word], - src0[scale_idx + 96u + 2u * sc_word + 1u] - )); - let sc_val = get_byte_i32(sc, sc_idx % 4u); - - let d = src0[scale_idx + 104u]; + let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx); + let sc_val = get_byte_i32(sc, 0u); + + let d = load_src0_f16_at(block_byte_base + 208u); var q_val: f16; if (quarter == 0u) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 94f4bae11f4..6525f23bdfc 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -52,8 +52,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q4_0 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 18u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -62,14 +62,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); + let d = f32(load_src0_f16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1 + block_offset + j]; - let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; @@ -86,8 +85,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q4_1 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 20u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 10u; const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -96,15 +95,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let m = f32(src0[scale_idx + 1u]); + let d = f32(load_src0_f16_at(block_byte_base)); + let m = f32(load_src0_f16_at(block_byte_base + 2u)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = f32((q_byte >> 4) & 0xF) * d + m; @@ -121,8 +119,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q5_0 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 22u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 11u; const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -131,18 +129,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let qh0 = src0[scale_idx + 1u]; - let qh1 = src0[scale_idx + 2u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = f32(load_src0_f16_at(block_byte_base)); + let qh_packed = load_src0_u32_at(block_byte_base + 2u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); + let q_packed = load_src0_u32_at(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -168,8 +163,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q5_1 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 24u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 12u; const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -178,19 +173,16 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let m = src0[scale_idx + 1u]; - let qh0 = src0[scale_idx + 2u]; - let qh1 = src0[scale_idx + 3u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = f32(load_src0_f16_at(block_byte_base)); + let m = load_src0_f16_at(block_byte_base + 2u); + let qh_packed = load_src0_u32_at(block_byte_base + 4u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); + let q_packed = load_src0_u32_at(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -216,8 +208,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q8_0 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 34u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 17u; const WEIGHTS_PER_F16 = 2u; const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -226,15 +218,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); + let d = f32(load_src0_f16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1 + block_offset + j]; - let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -250,8 +241,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q8_1 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 36u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 18u; const WEIGHTS_PER_F16 = 2u; const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -260,16 +251,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let m = src0[scale_idx + 1u]; + let d = f32(load_src0_f16_at(block_byte_base)); + let m = load_src0_f16_at(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d + f32(m); @@ -284,13 +274,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q6_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 105u; - -fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 { - let aligned = byte_offset & ~3u; - let idx = bbase + aligned / 2u; - return bitcast(vec2(src0[idx], src0[idx + 1u])); -} +const BLOCK_SIZE_BYTES = 210u; fn byte_of(v: u32, b: u32) -> u32 { return (v >> (b * 8u)) & 0xFFu; @@ -323,16 +307,15 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { var local_sum = 0.0; for (var i = ix; i < nb; i += 2u) { - let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK; + let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES; - let d_raw = load_u32_at(bbase, 208u); - let d = f32(bitcast>(d_raw)[0]); + let d = f32(load_src0_f16_at(bbase + 208u)); - let ql1_u32 = load_u32_at(bbase, q_offset_l); - let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u); - let qh_u32 = load_u32_at(bbase, 128u + q_offset_h); - let sc_u32_0 = load_u32_at(bbase, sc_base_byte); - let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u); + let ql1_u32 = load_src0_u32_at(bbase + q_offset_l); + let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u); + let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h); + let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte); + let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u); let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); From 933bd1f79c925f9f1d563854dd1fb4e40c288568 Mon Sep 17 00:00:00 2001 From: Anav Prasad Date: Wed, 1 Apr 2026 07:07:24 +0000 Subject: [PATCH 071/249] CUDA: Add Flash Attention Support for Head Dimension 512 (llama/20998) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * flash attention support for head dimension 512 added * FA D=512 - match 576 configs, limit ncols2, revert vec cap * fix HIP tile kernel build for D=512 * fix HIP tile kernel occupancy for D=512 on AMD * Apply suggestions from code review Co-authored-by: Johannes Gäßler * fix tile FA compilation --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 30 ++++++++++++++- ggml/src/ggml-cuda/fattn-tile.cu | 4 ++ ggml/src/ggml-cuda/fattn-tile.cuh | 37 +++++++++++++++---- ggml/src/ggml-cuda/fattn.cu | 11 ++++-- ...attn-mma-f16-instance-ncols1_1-ncols2_8.cu | 1 + ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 + ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 + ...attn-mma-f16-instance-ncols1_2-ncols2_8.cu | 1 + ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 + ...attn-mma-f16-instance-ncols1_4-ncols2_8.cu | 1 + ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 + ...attn-mma-f16-instance-ncols1_8-ncols2_8.cu | 1 + .../fattn-tile-instance-dkq512-dv512.cu | 5 +++ .../template-instances/generate_cu_files.py | 4 +- 14 files changed, 86 insertions(+), 13 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index fff70c8eb89..b613ae61fb8 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,6 +66,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); @@ -80,6 +85,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); @@ -89,6 +99,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); @@ -103,6 +118,10 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); @@ -1552,7 +1571,7 @@ static __global__ void flash_attn_ext_f16( #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { + if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { NO_DEVICE_CODE; return; } @@ -1815,6 +1834,15 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64) DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) +extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); + // The number of viable configurations for Deepseek is very limited: extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 3fcb09b7a2b..25b16e83cac 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; + case 512: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst); + } break; case 576: { GGML_ASSERT(V->ne[0] == 512); ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index f3fa80ab23d..26721cc4c7d 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,6 +68,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) @@ -124,6 +128,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) @@ -187,6 +195,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 512, 1, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) @@ -251,6 +264,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) @@ -767,7 +785,7 @@ static __global__ void flash_attn_tile( #ifdef GGML_USE_WMMA_FATTN (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) || #endif // GGML_USE_WMMA_FATTN - (use_logit_softcap && !(DV == 128 || DV == 256)) + (use_logit_softcap && !(DV == 128 || DV == 256 || DV == 512)) ) { GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -1192,7 +1210,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; - if constexpr (DV == 512) { + if constexpr (DKQ == 576) { if (use_gqa_opt && gqa_ratio % 16 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1203,7 +1221,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm } } - if constexpr (DV <= 256) { + if constexpr (DKQ <= 512) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1214,13 +1232,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm return; } - if (use_gqa_opt && gqa_ratio % 2 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1(ctx, dst); return; } - - launch_fattn_tile_switch_ncols1(ctx, dst); - return; } GGML_ABORT("fatal error"); } @@ -1255,4 +1275,5 @@ extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(512, 512); extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index a25a890db6d..a21c5361048 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -135,6 +135,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(V->ne[0] == 256); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; + case 512: + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst); + break; case 576: { // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. GGML_ASSERT(V->ne[0] == 512); @@ -336,7 +340,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const case 128: case 112: case 256: - if (V->ne[0] != K->ne[0]) { + case 512: + if (!gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; } break; @@ -424,7 +429,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) { + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) { if (can_use_vector_kernel && Q->ne[1] <= 2) { return BEST_FATTN_KERNEL_VEC; } @@ -457,7 +462,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use MFMA flash attention for CDNA (MI100+): - if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) { + if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); // MMA vs tile crossover benchmarked on MI300X @ d32768: // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index dc16829021f..22d383173f3 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index 517993cb068..d2415bfa957 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index 97b19c67ade..8eec1d74e29 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 163b1d939e4..84b674cd05a 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index 989626dfa5e..3475dfea08a 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index bad296b4141..5906398db91 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 173de7aac7d..684cd25ce0d 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 680a13ca6de..4bc60d62f91 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu new file mode 100644 index 00000000000..7c61d8d2ecd --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(512, 512); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 3b5ab12fc40..b7b5832293e 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,7 +3,7 @@ from glob import glob import os -HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] +HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576] TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] @@ -83,6 +83,8 @@ def get_short_name(long_quant_name): continue if head_size_kq == 72: continue + if head_size_kq == 512 and ncols2 not in (4, 8): + continue if head_size_kq != 576 and ncols2 in (16, 32): continue if head_size_kq == 576 and ncols2 not in (4, 16, 32): From 1b95f84550d32e59e9bbef4eaab0e0ce9240bf90 Mon Sep 17 00:00:00 2001 From: Taimur Ahmad Date: Wed, 1 Apr 2026 13:10:03 +0500 Subject: [PATCH 072/249] ggml-cpu: fix fallback for RVV kernels without zvfh (llama/21157) * ggml-cpu: refactor sgemm; fix rvv checks * ggml-cpu: refactor rvv kernels; set zvfbfwma default to off --- ggml/CMakeLists.txt | 19 +- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 147 +++++++------ ggml/src/ggml-cpu/vec.h | 292 +++++++++++++------------- 3 files changed, 239 insertions(+), 219 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index a739cca4218..ab558438e95 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -166,15 +166,16 @@ if (NOT MSVC) option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF) option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF) endif() -option(GGML_LASX "ggml: enable lasx" ON) -option(GGML_LSX "ggml: enable lsx" ON) -option(GGML_RVV "ggml: enable rvv" ON) -option(GGML_RV_ZFH "ggml: enable riscv zfh" ON) -option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) -option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) -option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause " ON) -option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) -option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE}) +option(GGML_LASX "ggml: enable lasx" ON) +option(GGML_LSX "ggml: enable lsx" ON) +option(GGML_RVV "ggml: enable rvv" ON) +option(GGML_RV_ZFH "ggml: enable riscv zfh" ON) +option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) +option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) +option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause" ON) +option(GGML_RV_ZVFBFWMA "ggml: enable riscv zvfbfwma" OFF) +option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) +option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE}) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 63ceb635dea..34e320e2f50 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -180,44 +180,49 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { } #endif +#if defined(__riscv_v_intrinsic) +template <> inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) { + return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); +} +template <> inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) { + return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); +} +template <> inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) { + return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); +} +template <> inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) { + return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); +} +#endif + #if defined(__riscv_zvfh) -template <> -inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) { +template <> inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) { return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); } -inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) { +template <> inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) { return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); } -inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) { +template <> inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) { return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); } -inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) { +template <> inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) { return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); } -inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) { - return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); -} -inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) { - return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); -} -inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) { - return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); -} -inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) { - return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); -} #endif #if defined(__riscv_zvfbfwma) -inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) { +template <> inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) { return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); } -inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) { +template <> inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) { return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); } -inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) { +template <> inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) { return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); } +template <> inline vfloat32m8_t madd(vbfloat16m4_t a, vbfloat16m4_t b, vfloat32m8_t c) { + return __riscv_vfwmaccbf16_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); +} #endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -272,7 +277,7 @@ inline float hsum(__m512 x) { } #endif // __AVX512F__ -#if defined(__riscv_zvfh) +#if defined(__riscv_v_intrinsic) inline float hsum(vfloat32m1_t x) { return __riscv_vfmv_f_s_f32m1_f32( __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1())); @@ -379,6 +384,21 @@ template <> inline __m256bh load(const float *p) { } #endif +#if defined(__riscv_v_intrinsic) +template <> inline vfloat32m1_t load(const float *p) { + return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1()); +} +template <> inline vfloat32m2_t load(const float *p) { + return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2()); +} +template <> inline vfloat32m4_t load(const float *p) { + return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4()); +} +template <> inline vfloat32m8_t load(const float *p) { + return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8()); +} +#endif + #if defined(__riscv_zvfh) template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) { return __riscv_vle16_v_f16mf2(reinterpret_cast(p), __riscv_vsetvlmax_e16mf2()); @@ -392,18 +412,6 @@ template <> inline vfloat16m2_t load(const ggml_fp16_t *p) { template <> inline vfloat16m4_t load(const ggml_fp16_t *p) { return __riscv_vle16_v_f16m4(reinterpret_cast(p), __riscv_vsetvlmax_e16m4()); } -template <> inline vfloat32m1_t load(const float *p) { - return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1()); -} -template <> inline vfloat32m2_t load(const float *p) { - return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2()); -} -template <> inline vfloat32m4_t load(const float *p) { - return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4()); -} -template <> inline vfloat32m8_t load(const float *p) { - return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8()); -} #endif #if defined(__riscv_zvfbfwma) @@ -416,23 +424,14 @@ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) { template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) { return __riscv_vle16_v_bf16m2(reinterpret_cast(p), __riscv_vsetvlmax_e16m2()); } +template <> inline vbfloat16m4_t load(const ggml_bf16_t *p) { + return __riscv_vle16_v_bf16m4(reinterpret_cast(p), __riscv_vsetvlmax_e16m4()); +} #endif -#if defined(__riscv_zvfh) +#if defined(__riscv_v_intrinsic) template T set_zero(); -template <> inline vfloat16mf2_t set_zero() { - return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2()); -} -template <> inline vfloat16m1_t set_zero() { - return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1()); -} -template <> inline vfloat16m2_t set_zero() { - return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2()); -} -template <> inline vfloat16m4_t set_zero() { - return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4()); -} template <> inline vfloat32m1_t set_zero() { return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1()); } @@ -449,14 +448,22 @@ template <> inline vfloat32m8_t set_zero() { #if defined(__riscv_v_intrinsic) template size_t vlmax() { - if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16mf2(); } - else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m1(); } - else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m2(); } - else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m4(); } - else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m1(); } + if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m1(); } else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m2(); } else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m4(); } else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m8(); } + #if defined (__riscv_zvfh) + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16mf2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m1(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m4(); } + #endif + #if defined (__riscv_zvfbfwma) + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16mf2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m1(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m4(); } + #endif return 0; } #endif @@ -3740,7 +3747,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; -#elif defined(__riscv_zvfh) +#elif defined(__riscv_v_intrinsic) #if LMUL == 1 tinyBLAS_RVV tb{ params, k, (const float *)A, lda, @@ -3804,23 +3811,25 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return true; } #elif defined(__riscv_zvfbfwma) - #if LMUL == 1 - tinyBLAS_RVV tb{ params, - k, (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc}; - #elif LMUL == 2 - tinyBLAS_RVV tb{ params, - k, (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc}; - #else // LMUL = 4 - tinyBLAS_RVV tb{ params, - k, (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc}; - #endif - return tb.matmul(m, n); + if (Btype == GGML_TYPE_BF16) { + #if LMUL == 1 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #elif LMUL == 2 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #else // LMUL = 4 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #endif + return tb.matmul(m, n); + } #endif return false; } diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 3198b33b509..a0375a28de0 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -126,7 +126,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG const int ggml_f16_epr = sve_register_length / 16; // running when 16 const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - const int np = (n & ~(ggml_f16_step - 1)); + int np = (n & ~(ggml_f16_step - 1)); svfloat16_t sum_00 = svdup_n_f16(0.0f); svfloat16_t sum_01 = svdup_n_f16(0.0f); @@ -224,71 +224,75 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG } GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03); GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13); + np = n; + #elif defined(__riscv_v_intrinsic) + #if defined(__riscv_zvfh) + size_t vl = __riscv_vsetvlmax_e32m4(); + + // initialize accumulators to all zeroes + vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + // calculate step size + const size_t epr = __riscv_vsetvlmax_e16m2(); + const size_t step = epr * 2; + int np = (n & ~(step - 1)); + + // unroll by 2 along the row dimension + for (int i = 0; i < np; i += step) { + vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr); + vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr); + vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr); + vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr); + vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr); + + vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr); + vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr); + vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr); + vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr); + vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr); + } - #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) - size_t vl = __riscv_vsetvlmax_e32m4(); - - // initialize accumulators to all zeroes - vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - - // calculate step size - const size_t epr = __riscv_vsetvlmax_e16m2(); - const size_t step = epr * 2; - const int np = (n & ~(step - 1)); - - // unroll by 2 along the row dimension - for (int i = 0; i < np; i += step) { - vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr); - vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr); - vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr); - vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr); - vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr); - - vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr); - vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr); - vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr); - vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr); - vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr); - } - - vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl); - vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl); - - // leftovers - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m2(n - i); - vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl); - vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl); - vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl); + vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl); + vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl); - vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl); - vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl); - } + // leftovers + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl); + vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl); + vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl); - // reduce - vl = __riscv_vsetvlmax_e32m2(); - vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0), - __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl); - vl = __riscv_vsetvlmax_e32m1(); - vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0), - __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl); - vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1( - acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); - - vl = __riscv_vsetvlmax_e32m2(); - vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0), - __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl); - vl = __riscv_vsetvlmax_e32m1(); - vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0), - __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl); - vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1( - acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); - sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0); - sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1); + vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl); + vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl); + } + // reduce + vl = __riscv_vsetvlmax_e32m2(); + vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0), + __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl); + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0), + __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl); + vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1( + acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + + vl = __riscv_vsetvlmax_e32m2(); + vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0), + __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl); + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0), + __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl); + vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1( + acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0); + sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1); + np = n; + #else + const int np = 0; + #endif #else const int np = (n & ~(GGML_F16_STEP - 1)); @@ -313,21 +317,17 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { GGML_F16_VEC_REDUCE(sumf[k], sum[k]); } - - // leftovers - for (int i = np; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); - } - } #endif #else - for (int i = 0; i < n; ++i) { + // scalar path + const int np = 0; +#endif + // scalar and leftovers + for (int i = np; i < n; ++i) { for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); } } -#endif for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { s[i] = (float)sumf[i]; @@ -532,40 +532,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, svst1_f16(pg, (__fp16 *)(y + np2), hy); } np = n; -#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic - const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); - const _Float16 scale = *(const _Float16*)(&s); - - // calculate step size - const int epr = __riscv_vsetvlmax_e16m4(); - const int step = epr * 2; - int np = (n & ~(step - 1)); - - // unroll by 2 - for (int i = 0; i < np; i += step) { - vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); - ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); - __asm__ __volatile__ ("" ::: "memory"); - - vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr); - vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); - ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); - __asm__ __volatile__ ("" ::: "memory"); - } +#elif defined(__riscv_v_intrinsic) // implies __riscv_v_intrinsic + #if defined (__riscv_zvfh) + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const _Float16*)(&s); - // leftovers - int vl; - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m4(n - i); - vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); - ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); - } - np = n; + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + int np = (n & ~(step - 1)); + + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr); + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } + + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); + } + np = n; + #else + // fall to scalar path + const int np = 0; + #endif #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -584,10 +589,11 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, } } #else + // scalar path const int np = 0; #endif - // leftovers + // scalar and leftovers for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); } @@ -785,7 +791,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float const int ggml_f16_step = 2 * ggml_f16_epr; GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); - const int np = (n & ~(ggml_f16_step - 1)); + int np = (n & ~(ggml_f16_step - 1)); svfloat16_t ay1, ay2; for (int i = 0; i < np; i += ggml_f16_step) { @@ -805,36 +811,43 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float svfloat16_t out = svmul_f16_m(pg, hy, vx); svst1_f16(pg, (__fp16 *)(y + np), out); } -#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) - const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); - const _Float16 scale = *(const _Float16*)(&s); - - // calculate step size - const int epr = __riscv_vsetvlmax_e16m4(); - const int step = epr * 2; - const int np = (n & ~(step - 1)); + np = n; +#elif defined(__riscv_v_intrinsic) + #if defined(__riscv_zvfh) + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const _Float16*)(&s); - // unroll by 2 - for (int i = 0; i < np; i += step) { - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); - ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); - __asm__ __volatile__ ("" ::: "memory"); + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + int np = (n & ~(step - 1)); - vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); - ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); - __asm__ __volatile__ ("" ::: "memory"); - } + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } - // leftovers - int vl; - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m4(n - i); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); - ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); - } + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); + } + np = n; + #else + // fall to scalar path + const int np = 0; + #endif #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -850,17 +863,14 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); } } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); - } #else - // scalar - for (int i = 0; i < n; ++i) { + // scalar path + const int np = 0; +#endif + // scalar and leftovers + for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); } -#endif } inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } From 5c5b88eb779cbd37a32c209c2034e6b56b55c4fe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 1 Apr 2026 11:10:25 +0300 Subject: [PATCH 073/249] ggml : fix RWKV ops thread assignment (llama/21226) --- ggml/src/ggml-cpu/ggml-cpu.c | 6 +++++- ggml/src/ggml-cpu/ops.cpp | 30 +++++++++--------------------- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index df17cc55300..7486acc2b5d 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2350,11 +2350,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: + { + n_tasks = n_threads; + } break; case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: { - n_tasks = n_threads; + const int64_t n_heads = node->src[1]->ne[1]; + n_tasks = MIN(n_threads, n_heads); } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index d950972c83e..765ce07f06c 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9953,13 +9953,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -10170,13 +10166,9 @@ static void ggml_compute_forward_gla_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -10633,13 +10625,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * r = (float *) dst->src[0]->data; float * w = (float *) dst->src[1]->data; From 1971a362dc008312762ed208cd0296bc23717901 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 1 Apr 2026 10:21:20 +0200 Subject: [PATCH 074/249] CUDA/HIP: Fix kernel slection for mmvq mmid kernel to align host selection with device launch bounds (llama/21238) The conditions cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE and cc >= GGML_CUDA_CC_TURING match all non-nvidia devices. This causes us to attempt to launch the kernel for batch sizes with larger configurations than our launch bounds on HIP devices. This pr fixes the conditionals in get_mmvq_mmid_max_batch. Fixes #21191 --- ggml/src/ggml-cuda/mmvq.cu | 43 ++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 8d80d1dd9a7..07b10167bc4 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -235,30 +235,33 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type // Host function: returns the max batch size for the current arch+type at runtime. int get_mmvq_mmid_max_batch(ggml_type type, int cc) { // NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID. - if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) { - return MMVQ_MAX_BATCH_SIZE; - } - if (cc >= GGML_CUDA_CC_TURING) { - return get_mmvq_mmid_max_batch_turing_plus(type); - } if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return MMVQ_MAX_BATCH_SIZE; + } + if (cc >= GGML_CUDA_CC_TURING) { + return get_mmvq_mmid_max_batch_turing_plus(type); + } return get_mmvq_mmid_max_batch_pascal_older(type); } + // AMD - if (GGML_CUDA_CC_IS_RDNA4(cc)) { - return get_mmvq_mmid_max_batch_rdna4(type); - } - if (GGML_CUDA_CC_IS_RDNA3(cc)) { - return get_mmvq_mmid_max_batch_rdna3(type); - } - if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) { - return get_mmvq_mmid_max_batch_rdna1_rdna2(type); - } - if (GGML_CUDA_CC_IS_CDNA(cc)) { - return get_mmvq_mmid_max_batch_cdna(type); - } - if (GGML_CUDA_CC_IS_GCN(cc)) { - return get_mmvq_mmid_max_batch_gcn(type); + if (GGML_CUDA_CC_IS_AMD(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return get_mmvq_mmid_max_batch_rdna4(type); + } + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + return get_mmvq_mmid_max_batch_rdna3(type); + } + if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) { + return get_mmvq_mmid_max_batch_rdna1_rdna2(type); + } + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return get_mmvq_mmid_max_batch_cdna(type); + } + if (GGML_CUDA_CC_IS_GCN(cc)) { + return get_mmvq_mmid_max_batch_gcn(type); + } } return MMVQ_MAX_BATCH_SIZE; } From ace95aac6b32f6e0e57a45789d3ec82c8c89e9ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 1 Apr 2026 16:01:45 +0300 Subject: [PATCH 075/249] ggml : bump version to 0.9.10 (ggml/1454) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index ab558438e95..2ffc3b391fe 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 9) +set(GGML_VERSION_PATCH 10) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) From 981195be5aa6ec359b2e536b4194c0f7a7a3ee20 Mon Sep 17 00:00:00 2001 From: Michael Wand Date: Wed, 1 Apr 2026 03:04:58 -0700 Subject: [PATCH 076/249] ggml-cuda: Add generic NVFP4 MMQ kernel (llama/21074) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduced NVFP4 generic MMQ kernel * Added extra FP8 guard, hope to solve ci HIP failure * Rename tiles and use HIP_FP8_AVAILABLE * Removed remaning FP8 straggler and added const int * Const * Removed DECL_MMQ_CASE artifact * Removed newline * Removed space after else * Changed HIP FP8 NVFP4 conversion gate * Added new line to bottom of mmq.cu 270 * Removed extra spaces * Removed single space in front of else on line 814 * Added NVFP4 to generate cu script so HIP can see it, further tightened logic * Include generated mmq-instance-nvfp4.cu * Added NVFP4 mmq to HIP Check ignore list * Update ggml/src/ggml-cuda/mmq.cuh Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4 Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/mmq.cuh Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4 in tile assert Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/mmq.cuh Added function name ending for end if Co-authored-by: Johannes Gäßler * Added function names to closing endif Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 27 ++++-- ggml/src/ggml-cuda/ggml-cuda.cu | 2 - ggml/src/ggml-cuda/mmq.cu | 5 +- ggml/src/ggml-cuda/mmq.cuh | 89 +++++++++++++++++-- .../template-instances/generate_cu_files.py | 2 +- .../template-instances/mmq-instance-nvfp4.cu | 5 ++ 6 files changed, 112 insertions(+), 18 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 7d7f20af3a0..9affe023403 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -800,19 +800,32 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { } static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) { -#ifdef FP8_AVAILABLE - const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. -#if defined(GGML_USE_HIP) && defined(CDNA3) - // ROCm dose not support fp8 in software on devices with fp8 hardware, +#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000 + // ROCm does not support fp8 in software on devices with fp8 hardware, // but CDNA3 supports only e4m3_fnuz (no inf). + const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast(&bits); + return static_cast(xf) / 2; #else +#if defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP) + const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. const __nv_fp8_e4m3 xf = *reinterpret_cast(&bits); -#endif // defined(GGML_USE_HIP) && defined(GGML_USE_HIP) return static_cast(xf) / 2; #else - NO_DEVICE_CODE; -#endif // FP8_AVAILABLE + if (x == 0 || (x == 0x7F && x != 0xFF)) { // Convert NaN to 0.0f + return 0.0f; + } + const int exp = (x >> 3) & 0xF; + const int man = x & 0x7; + float raw; + if (exp == 0) { + raw = ldexpf((float) man, -9); + } else { + raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7); + } + return static_cast(raw / 2); +#endif // defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP) +#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000 } __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index d1239b1c5f7..75b62129ade 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4791,9 +4791,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: -#ifdef FP8_AVAILABLE case GGML_TYPE_NVFP4: -#endif // FP8_AVAILABLE case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 9a69f41d159..27b4145ac9a 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -23,6 +23,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con case GGML_TYPE_MXFP4: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_NVFP4: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_Q2_K: mul_mat_q_case(ctx, args, stream); break; @@ -273,6 +276,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -362,5 +366,4 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t } return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; - } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 255e59f6fc6..51e8dad4ce7 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -68,6 +68,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_MXFP4: return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_NVFP4: + return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q2_K: return MMQ_Q8_1_DS_LAYOUT_D2S6; case GGML_TYPE_Q3_K: @@ -189,6 +191,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; @@ -206,12 +209,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml } } -#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) -#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) -#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) -#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) +#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 +#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 +#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); @@ -220,6 +224,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4"); +static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding."); + static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { switch (type) { @@ -230,6 +236,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; // tile sizes are the same for Q8_1 and FP4 for blackwell case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4; case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; @@ -826,6 +833,65 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr } } + +template +static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x, + int * __restrict__ x_tile, + const int kb0, + const int i_max, + const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4; + constexpr int rows_per_warp = warp_size / threads_per_row; + const int kbx = threadIdx.x % threads_per_row; + const int row_in_warp = threadIdx.x / threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) { + int i = i0 + threadIdx.y * rows_per_warp + row_in_warp; + + if constexpr (need_check) { + i = min(i, i_max); + } + + const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx; + const uint32_t * __restrict__ src_qs = reinterpret_cast(bxi->qs); + const int kqs = 16 * kbx; + const int ksc = 4 * kbx; + +#pragma unroll + for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) { + const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4); + const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y; + x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]); +#else + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y; + x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + } +} + template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -1229,7 +1295,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -// Used for Q3_K, IQ2_S, and IQ2_XS +// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -3261,6 +3327,14 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; @@ -4069,6 +4143,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); +extern DECL_MMQ_CASE(GGML_TYPE_NVFP4); extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index b7b5832293e..40d51f93fa4 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -35,7 +35,7 @@ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", - "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4" + "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4", "GGML_TYPE_NVFP4" ] SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu new file mode 100644 index 00000000000..2cb140d35a3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_NVFP4); From fab70d287e977d607247218e7e6e85b7f093adf3 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Wed, 1 Apr 2026 18:54:15 +0800 Subject: [PATCH 077/249] sycl : support nvfp4 type in mul_mat (llama/21227) --- ggml/src/ggml-sycl/common.hpp | 7 ++ ggml/src/ggml-sycl/convert.cpp | 18 +++++ ggml/src/ggml-sycl/dequantize.hpp | 32 +++++++++ ggml/src/ggml-sycl/mmvq.cpp | 22 +++++- ggml/src/ggml-sycl/type.hpp | 112 ++++++++++++++++++++++++++++++ ggml/src/ggml-sycl/vecdotq.hpp | 42 +++++++++++ 6 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-sycl/type.hpp diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index fcb0db99c6b..fd84c917853 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -23,6 +23,7 @@ #include "ggml-impl.h" #include "ggml-sycl.h" #include "presets.hpp" +#include "type.hpp" #include "sycl_hw.hpp" namespace syclexp = sycl::ext::oneapi::experimental; @@ -965,4 +966,10 @@ static T block_reduce(T val, T * shared_vals, int block_size_template) { return val; } +static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) { + const uint32_t bits = x * (x != 0x7F && x != 0xFF); + const __nv_fp8_e4m3 xf = *reinterpret_cast(&bits); + return static_cast(xf) / 2; +} + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index d17aca2cac4..d7f60cbc9ea 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -482,6 +482,18 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t }); } +template +static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + GGML_ASSERT(k % QK_NVFP4 == 0); + const int nb = k / QK_NVFP4; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_nvfp4(vx, y, k); + }); +} + + template static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, @@ -641,6 +653,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return dequantize_row_iq4_nl_sycl; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_sycl; case GGML_TYPE_F32: return convert_unary_sycl; #ifdef GGML_SYCL_HAS_BF16 @@ -648,6 +662,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return convert_unary_sycl; #endif default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); return nullptr; } } @@ -708,6 +723,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return dequantize_row_iq4_nl_sycl; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_sycl; case GGML_TYPE_F16: return convert_unary_sycl; #ifdef GGML_SYCL_HAS_BF16 @@ -715,6 +732,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return convert_unary_sycl; #endif default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); return nullptr; } } diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index da2a605daa8..3272724f41b 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -838,4 +838,36 @@ static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restr } } + +template +static void dequantize_block_nvfp4( + const void * __restrict__ vx, + dst_t * __restrict__ yy, + const int64_t ne) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = item_ct1.get_group(2); + const int tid = item_ct1.get_local_id(2); + + const int64_t base = i * QK_NVFP4; + if (base >= ne) { + return; + } + + const block_nvfp4 * x = (const block_nvfp4 *) vx; + const block_nvfp4 & xb = x[i]; + + const int sub = tid / (QK_NVFP4_SUB / 2); + const int j = tid % (QK_NVFP4_SUB / 2); + + const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j]; + + const int64_t y0 = base + sub * QK_NVFP4_SUB + j; + const int64_t y1 = y0 + QK_NVFP4_SUB / 2; + + yy[y0] = ggml_sycl_cast(d * kvalues_mxfp4[q & 0x0F]); + yy[y1] = ggml_sycl_cast(d * kvalues_mxfp4[q >> 4]); +} + + #endif // GGML_SYCL_DEQUANTIZE_HPP diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 316aa0d0fb5..5abc50fabfe 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -613,6 +613,23 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float } } +static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_NVFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -1145,8 +1162,11 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_MXFP4: mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; + case GGML_TYPE_NVFP4: + mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; default: - GGML_ABORT("fatal error"); + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type)); } } GGML_UNUSED(src1); diff --git a/ggml/src/ggml-sycl/type.hpp b/ggml/src/ggml-sycl/type.hpp new file mode 100644 index 00000000000..d7ff89d7d42 --- /dev/null +++ b/ggml/src/ggml-sycl/type.hpp @@ -0,0 +1,112 @@ +#pragma once + +#include +#include +#include + +inline uint8_t float_to_e4m3(float f) +{ + if (sycl::isnan(f)) { + return 0x7F; // Canonical NaN (positive) + } + + uint32_t bits = sycl::bit_cast(f); + uint32_t sign = (bits >> 31) & 0x1u; + uint32_t exp = (bits >> 23) & 0xFFu; + uint32_t mant = bits & 0x7FFFFFu; + + // Zero + if (exp == 0 && mant == 0) { + return static_cast(sign << 7); + } + + // Extract biased exponent and mantissa for FP8 + int e = static_cast(exp) - 127; // true exponent (IEEE bias 127) + uint32_t m = mant; + + // Handle very large values → NaN (NVIDIA behavior for E4M3) + if (e > 7) { // max exponent for E4M3 is 7 (biased 14) + return static_cast((sign << 7) | 0x7F); + } + + // Handle subnormals and normal numbers + if (e < -6) { // smallest normal exponent is -6 + // Subnormal in FP8: shift mantissa right + int shift = -6 - e; + m = (m | 0x800000u) >> (shift + 1); // +1 because we lose the implicit 1 position + if (shift > 23) m = 0; + } else { + // Normal number: adjust exponent bias from 127 to 7 + int new_exp = e + 7; + m = (m >> 20) & 0x7u; // take top 3 mantissa bits (after implicit 1) + m |= (static_cast(new_exp) << 3); + } + + // Round-to-nearest-even (simple guard + round bit) + // For better accuracy you can add sticky bit, but this is sufficient for most use cases + uint32_t round_bit = (mant >> 19) & 0x1u; // bit after the 3 mantissa bits + if (round_bit) { + m += 1; + // Carry into exponent if mantissa overflows + if ((m & 0x8u) != 0) { + m = (m & 0x7u) | ((m & 0x38u) << 1); // simple carry handling + // If exponent overflows after carry → NaN + if ((m >> 3) > 14) { + return static_cast((sign << 7) | 0x7F); + } + } + } + + uint8_t result = static_cast((sign << 7) | (m & 0x7F)); + return result; +} + +inline float e4m3_to_float(uint8_t x) +{ + if (x == 0) return 0.0f; + + uint8_t sign = (x >> 7) & 0x1u; + uint8_t exp = (x >> 3) & 0xFu; + uint8_t mant = x & 0x7u; + + // NaN (NVIDIA uses 0x7F / 0xFF as NaN) + if (exp == 0xF && mant != 0) { + return std::numeric_limits::quiet_NaN(); + } + if (exp == 0xF) { // 0x7F or 0xFF treated as NaN + return std::numeric_limits::quiet_NaN(); + } + + float val; + + if (exp == 0) { + // Subnormal + val = mant * (1.0f / 8.0f) * sycl::pow(2.0f, -6.0f); + } else { + // Normal: implicit leading 1 + bias 7 + val = (1.0f + mant / 8.0f) * sycl::pow(2.0f, static_cast(exp) - 7.0f); + } + + return sign ? -val : val; +} + +// The actual type definition +struct __nv_fp8_e4m3 { + uint8_t raw; + + __nv_fp8_e4m3() = default; + + explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {} + explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast(h))) {} + + operator float() const { return e4m3_to_float(raw); } + operator sycl::half() const { return static_cast(static_cast(*this)); } + + // Allow direct access for vector loads/stores + operator uint8_t&() { return raw; } + operator uint8_t() const { return raw; } +}; + +using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>; +using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>; + diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 9a267d85a0c..eab9850aed7 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -15,6 +15,7 @@ #include "dpct/helper.hpp" #include "ggml.h" +#include "type.hpp" #include "quants.hpp" typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, @@ -31,6 +32,18 @@ static __dpct_inline__ int get_int_b1(const void * x, const int & i32) { return x32; } +static __dpct_inline__ int get_int_b2(const void * x, const int & i32) { + const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment + + int x32 = x16[2*i32 + 0] << 0; + x32 |= x16[2*i32 + 1] << 16; + + return x32; +} + +static __dpct_inline__ int get_int_b4(const void * x, const int & i32) { + return ((const int *) x)[i32]; // assume at least 4 byte alignment +} static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) { const uint16_t* x16 = @@ -755,6 +768,35 @@ static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq, return d * sumi; } +#define VDR_NVFP4_Q8_1_MMVQ 4 +#define VDR_NVFP4_Q8_1_MMQ 8 + +static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq, + const block_q8_1 * __restrict__ bq8_1, + const int32_t & iqs) { + const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq; + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) { + const int32_t iqs0 = iqs + 2*i; + const int32_t iqs1 = iqs0 + 1; + const int32_t is = iqs0 >> 1; + const sycl::int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4); + const sycl::int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4); + const block_q8_1 * bq8 = bq8_1 + (is >> 1); + const int32_t i8 = ((is & 1) << 2); + + int sumi = ggml_sycl_dp4a(v0.x(), get_int_b4(bq8->qs, i8 + 0), 0); + sumi = ggml_sycl_dp4a(v0.y(), get_int_b4(bq8->qs, i8 + 2), sumi); + sumi = ggml_sycl_dp4a(v1.x(), get_int_b4(bq8->qs, i8 + 1), sumi); + sumi = ggml_sycl_dp4a(v1.y(), get_int_b4(bq8->qs, i8 + 3), sumi); + + const float d = ggml_sycl_ue4m3_to_fp32(bq4->d[is]) * (bq8->ds)[0]; + sum += d * float(sumi); + } + + return sum; +} static __dpct_inline__ float vec_dot_q5_0_q8_1(const void *__restrict__ vbq, From 9a40dd9365ac55c16a27200e8db3873dbb4c7cbd Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Wed, 1 Apr 2026 21:13:08 +0530 Subject: [PATCH 078/249] hexagon: improve RMS_NORM and DIV accuracy (llama/21251) * hexagon-rms_norm: fix RMS_NORM for non-aligned tensor sizes Co-authored-by: Krishna Sridhar * hexagon-div: perform DIV in fp16 domain for lower dsp archs --------- Co-authored-by: Krishna Sridhar --- ggml/src/ggml-hexagon/htp/hvx-div.h | 86 ++++++++++++++++++++------- ggml/src/ggml-hexagon/htp/unary-ops.c | 41 ++++++++++--- 2 files changed, 97 insertions(+), 30 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hvx-div.h b/ggml/src/ggml-hexagon/htp/hvx-div.h index 05cefea039f..53ee304e749 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-div.h +++ b/ggml/src/ggml-hexagon/htp/hvx-div.h @@ -16,8 +16,10 @@ #if __HVX_ARCH__ < 79 #define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#define HVX_OP_MUL_F16(a, b) Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)) #else #define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#define HVX_OP_MUL_F16(a, b) Q6_Vhf_vmpy_VhfVhf(a, b) #endif // Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32. @@ -43,46 +45,67 @@ static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX return res; } -#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \ - do { \ - dst_type * restrict vdst = (dst_type *) dst; \ - src_type * restrict vsrc = (src_type *) src; \ - HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ - \ - const uint32_t nvec = n / VLEN_FP16; \ - const uint32_t nloe = n % VLEN_FP16; \ - \ - uint32_t i = 0; \ - \ - _Pragma("unroll(4)") \ - for (; i < nvec; i++) { \ - HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ - vdst[i] = res; \ - } \ - if (nloe) { \ - HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ - vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ - } \ +// Variant for =v79 +static inline HVX_Vector hvx_vec_hybrid_div_f16(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector f16_nan_inf_mask, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + // For older architectures, use f16 reciprocal to avoid NaN/-inf issues + HVX_Vector vec2_inv = hvx_vec_inverse_f16_guard(vec2, f16_nan_inf_mask); + return HVX_OP_MUL_F16(vec1, vec2_inv); +#else + return hvx_vec_div_f16_using_f32(vec1, vec2, f32_nan_inf_mask, vec_hf_one_1_0); +#endif +} + #define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src0_type * restrict vsrc0 = (src0_type *) src0; \ src1_type * restrict vsrc1 = (src1_type *) src1; \ \ - const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + const HVX_Vector f32_nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + const HVX_Vector f16_nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \ const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ \ const uint32_t nvec = n / VLEN_FP16; \ @@ -144,11 +179,15 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v \ _Pragma("unroll(4)") \ for (; i < nvec; i++) { \ - HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \ + f32_nan_inf_mask, f16_nan_inf_mask, \ + hf_one); \ vdst[i] = res; \ } \ if (nloe) { \ - HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \ + f32_nan_inf_mask, f16_nan_inf_mask, \ + hf_one); \ vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ } \ } while(0) @@ -247,5 +286,6 @@ HVX_DIV_DISPATCHER(hvx_div_f32) HVX_DIV_DISPATCHER(hvx_div_f16) #undef HVX_OP_MUL_F32 +#undef HVX_OP_MUL_F16 #endif // HVX_DIV_H diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 3d0928d4dce..13d28317d5c 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -67,34 +67,61 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, uint8_t * restrict pad, const int num_elems, float epsilon) { + (void)pad; + const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; - HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares for full vectors + HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); - int step_of_1 = num_elems >> 5; #pragma unroll(4) - for (int i = 0; i < step_of_1; i++) { + for (int i = 0; i < nvec; i++) { HVX_Vector v1 = v_src[i]; HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); - sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); } - sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes + // Reduce HVX sum + sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v); HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v); + // Scale full vectors HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); #pragma unroll(4) - for (int i = 0; i < step_of_1; i++) { + for (int i = 0; i < nvec; i++) { HVX_Vector v1 = v_src[i]; HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); - v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + HVX_Vector result = Q6_Vsf_equals_Vqf32(v2); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); } } From 82bb26fba1b4de5180009ae5a2a20537efba8ee7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 1 Apr 2026 21:28:19 +0200 Subject: [PATCH 079/249] CUDA: fix FA kernel selection logic (llama/21271) --- ggml/src/ggml-cuda/fattn.cu | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index a21c5361048..addf93205ef 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -340,7 +340,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const case 128: case 112: case 256: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 512: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } if (!gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; } From 08108512c7c3ae2610d1e5f36c80cd7d3a753987 Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 1 Apr 2026 12:54:58 -0700 Subject: [PATCH 080/249] opencl: fix leak in Adreno q8_0 path (llama/21212) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0f6628c377d..6f3fc5886d8 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -9612,6 +9612,9 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t cl_mem B_image1d; cl_mem B_sub_buffer; cl_mem S_image1d; + // for B transpose + cl_mem B_image1d_trans = nullptr; + cl_mem B_d = nullptr; cl_mem D_image1d; cl_mem D_sub_buffer; @@ -9703,9 +9706,6 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t global_work_size[2] = 1; } else { cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_mem B_image1d_trans = nullptr; - // for B transpose - cl_mem B_d = nullptr; int padding; //how many extra elements beyond multiple of 8 @@ -9800,6 +9800,12 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t CL_CHECK(clReleaseMemObject(S_image1d)); CL_CHECK(clReleaseMemObject(D_sub_buffer)); CL_CHECK(clReleaseMemObject(D_image1d)); + if (B_image1d_trans) { + CL_CHECK(clReleaseMemObject(B_image1d_trans)); + } + if (B_d) { + CL_CHECK(clReleaseMemObject(B_d)); + } #else GGML_UNUSED(backend); GGML_UNUSED(src0); From 444662bc8307fc7a5d49acde48ae32e3c51b280b Mon Sep 17 00:00:00 2001 From: Todor Boinovski Date: Wed, 1 Apr 2026 17:44:02 -0700 Subject: [PATCH 081/249] hexagon : add cumsum op support (llama/21246) * hexagon : add cumsum op support * hexagon: enable dma for cumsum op * Fix line-ending --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 34 +++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/cumsum-ops.c | 267 +++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-msg.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 43 ++++ 6 files changed, 347 insertions(+) create mode 100644 ggml/src/ggml-hexagon/htp/cumsum-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index dd604db4333..f91bc46552e 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2231,6 +2231,22 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + enum dspqbuf_type { DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0, DSPQBUF_TYPE_CPU_WRITE_DSP_READ, @@ -2399,6 +2415,16 @@ static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bu return n_bufs; } +static inline size_t init_cumsum_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_CUMSUM; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { req->op = HTP_OP_GET_ROWS; @@ -2780,6 +2806,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_CUMSUM: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -3254,6 +3284,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_ssm_conv(sess, op); break; + case GGML_OP_CUMSUM: + supp = ggml_hexagon_supported_cumsum(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 6ddfe4252f5..2b60f427ada 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -33,6 +33,7 @@ add_library(${HTP_LIB} SHARED repeat-ops.c argsort-ops.c ssm-conv.c + cumsum-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/cumsum-ops.c b/ggml/src/ggml-hexagon/htp/cumsum-ops.c new file mode 100644 index 00000000000..ce51555a7fd --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/cumsum-ops.c @@ -0,0 +1,267 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hvx-utils.h" +#include "hex-dma.h" + +#define htp_cumsum_tensors_preamble \ + struct htp_tensor * restrict src0 = &octx->src0; \ + struct htp_tensor * restrict dst = &octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_cumsum_context { + struct htp_ops_context * octx; + size_t src_row_size; + size_t dst_row_size; + size_t src_row_size_aligned; + size_t dst_row_size_aligned; + uint32_t rows_per_thread; + uint32_t total_rows; +}; + +#define htp_cumsum_preamble \ + struct htp_cumsum_context * cctx = (struct htp_cumsum_context *) data; \ + struct htp_ops_context * octx = cctx->octx; \ + htp_cumsum_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; + +// --------------------------------------------------------------------------- +// HVX prefix scan helpers +// --------------------------------------------------------------------------- + +#if __HVX_ARCH__ > 75 +static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vadd_VsfVsf(a, b); +} +#else +static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)); +} +#endif // __HVX_ARCH__ > 75 + +static inline HVX_Vector hvx_prefix_scan_f32(HVX_Vector v, HVX_Vector carry_in) { + const HVX_Vector zero = Q6_V_vsplat_R(0); + + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 4)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 8)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 16)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 32)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 64)); + v = hvx_cumsum_vadd(v, carry_in); + + return v; +} + +static inline HVX_Vector hvx_splat_last_f32(HVX_Vector v) { + return hvx_vec_repl4(Q6_V_vror_VR(v, 124)); +} + +static inline void hvx_cumsum_row_f32(const float * restrict src, float * restrict dst, uint32_t n) { + const uint32_t nvec = n / VLEN_FP32; + const uint32_t nloe = n % VLEN_FP32; + + HVX_Vector carry = Q6_V_vsplat_R(0); + + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v = *((const HVX_UVector *) (src + i * VLEN_FP32)); + v = hvx_prefix_scan_f32(v, carry); + hvx_vec_store_u(dst + i * VLEN_FP32, VLEN, v); + carry = hvx_splat_last_f32(v); + } + + if (nloe) { + float acc = hvx_vec_get_f32(carry); + const float * src_tail = src + nvec * VLEN_FP32; + float * dst_tail = dst + nvec * VLEN_FP32; + for (uint32_t i = 0; i < nloe; i++) { + acc += src_tail[i]; + dst_tail[i] = acc; + } + } +} + +// --------------------------------------------------------------------------- +// Per thread worker: Double-buffered DMA +// --------------------------------------------------------------------------- + +static void cumsum_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) { + htp_cumsum_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ir0 = cctx->rows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows); + + if (ir0 >= ir1) { + return; + } + + const size_t src_row_size = cctx->src_row_size; + const size_t dst_row_size = cctx->dst_row_size; + const size_t src_row_size_aligned = cctx->src_row_size_aligned; + const size_t dst_row_size_aligned = cctx->dst_row_size_aligned; + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + uint8_t * src_spad = octx->src0_spad.data + (ith * src_row_size_aligned * 2); + uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned * 2); + + for (uint32_t ir = ir0, spad_idx = 0; ir < ir1 && spad_idx < 2; ir++, spad_idx++) { + // Dummy dst writeback to establish queue ordering + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_data, dst_spad + (spad_idx * dst_row_size_aligned)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src_spad + (spad_idx * src_row_size_aligned), + src_data + (ir * src_row_size)), + src_row_size_aligned, src_row_size, 1); + } + + for (uint32_t ir = ir0; ir < ir1; ir++) { + float * dst_spad_row = (float *) dma_queue_pop(dma_queue).src; + float * src_spad_row = (float *) dma_queue_pop(dma_queue).dst; + + hvx_cumsum_row_f32(src_spad_row, dst_spad_row, ne00); + + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_data + (ir * dst_row_size), (uint8_t *) dst_spad_row), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < ir1) { + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr((uint8_t *) src_spad_row, src_data + (next_row * src_row_size)), + src_row_size_aligned, src_row_size, 1); + } + } + + dma_queue_flush(dma_queue); + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "cumsum-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// Per thread worker: Direct HVX (no DMA) +// --------------------------------------------------------------------------- + +static void cumsum_thread_f32(unsigned int nth, unsigned int ith, void * data) { + htp_cumsum_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + const uint32_t ir0 = cctx->rows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows); + + for (uint32_t ir = ir0; ir < ir1; ir++) { + const float * restrict src_row = (const float *) (src_data + ir * cctx->src_row_size); + float * restrict dst_row = (float *) (dst_data + ir * cctx->dst_row_size); + hvx_cumsum_row_f32(src_row, dst_row, ne00); + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "cumsum-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_cumsum_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t total_rows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_rows); + + const size_t src_row_size = src0->nb[1]; + const size_t dst_row_size = dst->nb[1]; + const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // 2 ping-pong buffers per thread for src and dst + const size_t spad_per_thread = 2 * (src_row_size_aligned + dst_row_size_aligned); + + octx->src0_spad.size_per_thread = src_row_size_aligned * 2; + octx->dst_spad.size_per_thread = dst_row_size_aligned * 2; + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + struct htp_cumsum_context cctx = { + .octx = octx, + .src_row_size = src_row_size, + .dst_row_size = dst_row_size, + .src_row_size_aligned = src_row_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + .rows_per_thread = (total_rows + n_threads - 1) / n_threads, + .total_rows = total_rows, + }; + + if (octx->ctx->vtcm_size < spad_per_thread * n_threads) { + worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32, &cctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32_dma, &cctx, n_threads); + } + + return HTP_STATUS_OK; +} + +int op_cumsum(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + struct htp_tensor * dst = &octx->dst; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_cumsum_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 391148be0e9..df0ea7ccbd6 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -75,6 +75,7 @@ enum htp_op { HTP_OP_SUM_ROWS, HTP_OP_SSM_CONV, HTP_OP_REPEAT, + HTP_OP_CUMSUM, INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index f643fdc340d..d35decaac20 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -60,5 +60,6 @@ int op_cpy(struct htp_ops_context * octx); int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); +int op_cumsum(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 49f34b5f7d1..6f37bf9d4b8 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -860,6 +860,41 @@ static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_cumsum_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We've written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_cumsum(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_activations_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, @@ -1474,6 +1509,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_ssm_conv_req(ctx, &req, bufs); break; + case HTP_OP_CUMSUM: + if (n_bufs != 2) { + FARF(ERROR, "Bad cumsum-req buffer list"); + continue; + } + proc_cumsum_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; From 514eabc1e5c67a32d2cc5990bf729af0f9802be1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 2 Apr 2026 10:37:26 +0300 Subject: [PATCH 082/249] ggml : bump version to 0.9.11 (ggml/1456) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 2ffc3b391fe..5834e544b48 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 10) +set(GGML_VERSION_PATCH 11) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) From 7f6c0ac20f09ed85a3b00c4bb0665a2a091ed770 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Thu, 2 Apr 2026 15:08:32 +0800 Subject: [PATCH 083/249] sycl : fix llama_kv_cache hang when kv_cache is huge: 5GB (llama/21283) --- ggml/src/ggml-sycl/ggml-sycl.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 456b1699fa3..28be4939784 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -569,9 +569,15 @@ static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer, SYCL_CHECK( CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw())); - SYCL_CHECK(CHECK_TRY_ERROR((*stream) - .memset(ctx->dev_ptr, value, buffer->size) - .wait())); + constexpr size_t MAX_CHUNK = 2ULL << 30; // 2 GiB + for (size_t off = 0; off < buffer->size; off += MAX_CHUNK) { + size_t chunk = std::min(buffer->size - off, MAX_CHUNK); + SYCL_CHECK(CHECK_TRY_ERROR( + (*stream) + .memset(static_cast(ctx->dev_ptr) + off, value, chunk) + .wait() + )); + } } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ From c5a5e6528ec6002cd1d84f7a11c42255f4550044 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Thu, 2 Apr 2026 10:40:42 -0700 Subject: [PATCH 084/249] ggml-webgpu: add vectorized flash attention (llama/20709) * naive vectorized version * add vectorized flash attention * update vec version * remove unused path and shader * remove unused helper functions * add comments * remove pad path * ggml-webgpu: fix flash-attn vec nwg=1 path and tighten vec specialization * change back to vec4 * enable multi split * enable vec path when: - Q->ne[1] < 20 - Q->ne[0] % 32 == 0 - V->ne[0] % 4 == 0 - K->type == f16 * update flast_attn_vec_split.wgsl to reduce redundant workgroup barrier usage and use select * enable vec path for q4 and q8 * flash-attn vec nwg=1 fast path (skip tmp/reduce staging) * use packed f16 K loads in flash-attn vec split * use packed f16 K loads in flash-attn vec split on host side * tune flash-attn vec f16 VEC_NE by head dim * cleanup * cleanup * keep host side clean * cleanup host side * change back to original host wait/submit behavior * formatting * reverted param-buffer pool r ecfactor * add helper functions * ggml-webgpu: move flash-attn vec pipeline caching back into shader lib * ggml-webgpu: remove duplicate functions * ggml-webgpu: reserve flash-attn vec scratch in dst buffer allocation * ggml-webgpu: revert unrelated change * ggml-webgpu: revert deleted comment * disable uniformity check * remove unnecessary change * Update ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl * Update ggml/src/ggml-webgpu/ggml-webgpu.cpp --------- Co-authored-by: Reese Levine --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 230 +++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 323 +++++++- .../wgsl-shaders/flash_attn_vec_blk.wgsl | 105 +++ .../wgsl-shaders/flash_attn_vec_reduce.wgsl | 78 ++ .../wgsl-shaders/flash_attn_vec_split.wgsl | 729 ++++++++++++++++++ 5 files changed, 1412 insertions(+), 53 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index a194ce84e25..1c56c689312 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -95,6 +95,12 @@ struct ggml_webgpu_generic_shader_decisions { uint32_t wg_size = 0; }; +struct ggml_webgpu_processed_shader { + std::string wgsl; + std::string variant; + std::shared_ptr decisions; +}; + struct ggml_webgpu_ssm_conv_shader_decisions { uint32_t block_size; uint32_t tokens_per_wg; @@ -384,11 +390,12 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_mask; bool has_sinks; bool uses_logit_softcap; + bool use_vec; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap; + uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec; } }; @@ -402,6 +409,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.has_mask); ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); + ggml_webgpu_hash_combine(seed, key.use_vec); return seed; } }; @@ -421,6 +429,115 @@ struct ggml_webgpu_flash_attn_shader_decisions { uint32_t wg_size = 0; }; +inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { + // Keep conservative defaults unless this is the f16 vec-split shape family. + if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) { + return 1u; + } + + // Head-dim specializations used by the tuned vec f16 path. + switch (key.head_dim_qk) { + case 64: return 2u; + case 96: return 4u; + case 128: return 1u; + case 192: return 2u; + case 576: return 2u; + default: return 1u; + } +} + +struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { + uint32_t head_dim_v; + uint32_t wg_size; +}; + +struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.wg_size); + return seed; + } +}; + +inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs, + const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) { + return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size; +} + +struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context { + ggml_webgpu_flash_attn_vec_reduce_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { + std::vector defines; + std::string variant = "flash_attn_vec_reduce"; + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + variant += std::string("_wg") + std::to_string(context.max_wg_size); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + return result; +} + +struct ggml_webgpu_flash_attn_blk_pipeline_key { + uint32_t q_tile; + uint32_t kv_tile; + + bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { + return q_tile == other.q_tile && kv_tile == other.kv_tile; + } +}; + +struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.q_tile); + ggml_webgpu_hash_combine(seed, key.kv_tile); + return seed; + } +}; + +struct ggml_webgpu_flash_attn_blk_shader_lib_context { + ggml_webgpu_flash_attn_blk_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { + std::vector defines; + std::string variant = "flash_attn_vec_blk"; + + defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile)); + variant += std::string("_qt") + std::to_string(context.key.q_tile); + + defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile)); + variant += std::string("_kvt") + std::to_string(context.key.kv_tile); + + uint32_t wg_size = 1; + while ((wg_size << 1) <= context.max_wg_size) { + wg_size <<= 1; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + variant += std::string("_wg") + std::to_string(wg_size); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + return result; +} + // This is exposed because it's necessary in supports_op inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, @@ -659,6 +776,14 @@ class ggml_webgpu_shader_lib { repeat_pipelines; // type std::unordered_map flash_attn_pipelines; + std::unordered_map + flash_attn_vec_reduce_pipelines; + std::unordered_map + flash_attn_blk_pipelines; std::unordered_map @@ -1673,24 +1798,8 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { - const bool has_mask = context.src3 != nullptr; - const bool has_sinks = context.src4 != nullptr; - - bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && - (context.src1->ne[1] % context.sg_mat_n == 0); - - ggml_webgpu_flash_attn_pipeline_key key = { - .kv_type = context.src1->type, - .head_dim_qk = (uint32_t) context.src0->ne[0], - .head_dim_v = (uint32_t) context.src2->ne[0], - .kv_direct = kv_direct, - .has_mask = has_mask, - .has_sinks = has_sinks, - .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f, - }; - - auto it = flash_attn_pipelines.find(key); + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) { + auto it = flash_attn_pipelines.find(context.key); if (it != flash_attn_pipelines.end()) { return it->second; } @@ -1698,7 +1807,7 @@ class ggml_webgpu_shader_lib { std::vector defines; std::string variant = "flash_attn"; - switch (key.kv_type) { + switch (context.key.kv_type) { case GGML_TYPE_F32: defines.push_back("KV_F32"); break; @@ -1714,41 +1823,52 @@ class ggml_webgpu_shader_lib { default: GGML_ABORT("Unsupported KV type for flash attention shader"); } - variant += std::string("_") + ggml_type_name(key.kv_type); + variant += std::string("_") + ggml_type_name(context.key.kv_type); - if (key.has_mask) { + if (context.key.has_mask) { defines.push_back("MASK"); variant += "_mask"; } - if (key.has_sinks) { + if (context.key.has_sinks) { defines.push_back("SINKS"); variant += "_sinks"; } - if (key.uses_logit_softcap) { + if (context.key.uses_logit_softcap) { defines.push_back("LOGIT_SOFTCAP"); variant += "_lgsc"; } - if (key.kv_direct) { + if (context.key.kv_direct) { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } + if (context.key.has_mask && context.key.use_vec) { + defines.push_back("BLK"); + variant += "_blk"; + } - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(key.head_dim_v); + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - uint32_t q_tile = context.sg_mat_m; + uint32_t q_tile = context.sg_mat_m; uint32_t kv_tile = - std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k, - context.wg_mem_limit_bytes, context.max_subgroup_size }), + std::min(ggml_webgpu_flash_attn_max_kv_tile(context), context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (key.kv_direct) { + if (context.key.use_vec) { + q_tile = 1; + kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context))); + kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; + const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key); + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); + } + if (context.key.kv_direct) { + GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { kv_tile -= context.sg_mat_n; } @@ -1757,19 +1877,51 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); - uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + uint32_t wg_size = 0; + if (context.key.use_vec) { + wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + } else { + wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - auto processed = preprocessor.preprocess(wgsl_flash_attn, defines); + const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); auto decisions = std::make_shared(); decisions->q_tile = q_tile; decisions->kv_tile = kv_tile; decisions->wg_size = wg_size; + pipeline.context = decisions; + flash_attn_pipelines[context.key] = pipeline; + return flash_attn_pipelines[context.key]; + } + + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { + auto it = flash_attn_blk_pipelines.find(context.key); + if (it != flash_attn_blk_pipelines.end()) { + return it->second; + } + + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); + flash_attn_blk_pipelines[context.key] = pipeline; + return flash_attn_blk_pipelines[context.key]; + } + + webgpu_pipeline get_flash_attn_vec_reduce_pipeline( + const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { + auto it = flash_attn_vec_reduce_pipelines.find(context.key); + if (it != flash_attn_vec_reduce_pipelines.end()) { + return it->second; + } - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - flash_attn_pipelines[key] = pipeline; - return flash_attn_pipelines[key]; + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); + flash_attn_vec_reduce_pipelines[context.key] = pipeline; + return flash_attn_vec_reduce_pipelines[context.key]; } webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1aa15b0507c..e53281bfbbd 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -658,7 +658,6 @@ static webgpu_command ggml_backend_webgpu_build_multi( for (size_t i = 0; i < params_bufs_list.size(); i++) { ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); } - #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { @@ -1481,7 +1480,6 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); } -#ifndef __EMSCRIPTEN__ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * Q, ggml_tensor * K, @@ -1565,30 +1563,248 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = Q, - .src1 = K, - .src2 = V, - .src3 = mask, - .src4 = sinks, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); + + const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned && + (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); + const uint32_t vec_nwg_cap = + std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const bool use_blk = use_vec && has_mask; + + ggml_webgpu_flash_attn_pipeline_key key = { + .kv_type = K->type, + .head_dim_qk = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .kv_direct = kv_direct, + .has_mask = static_cast(has_mask), + .has_sinks = static_cast(has_sinks), + .uses_logit_softcap = logit_softcap != 0.0f, + .use_vec = use_vec, + }; + + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { + .key = key, .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, }; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + + wgpu::Buffer blk_buf = {}; + uint64_t blk_size_bytes = 0; + uint32_t blk_nblk0 = 0; + uint32_t blk_nblk1 = 0; + uint32_t blk_batch_count = 0; + + if (use_vec) { + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + const bool use_vec_reduce = nwg > 1u; + GGML_ASSERT(nrows <= UINT32_MAX); + + uint64_t tmp_stats_base = 0; + uint64_t tmp_size_bytes = 0; + wgpu::Buffer tmp_buf = {}; + uint64_t tmp_bind_offset = 0; + uint64_t tmp_bind_size = 0; + const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); + + if (use_vec_reduce) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + tmp_stats_base = tmp_data_elems; + tmp_size_bytes = + ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + GGML_ASSERT(tmp_stats_base <= UINT32_MAX); + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); + } else { + // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); + tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); + } + + webgpu_pipeline blk_pipeline; + std::vector blk_params; + std::vector blk_entries; + if (use_blk) { + GGML_ASSERT(has_mask); + + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); + blk_buf = ggml_webgpu_tensor_buf(dst); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { + .key = + { + .q_tile = decisions->q_tile, + .kv_tile = decisions->kv_tile, + }, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); + + blk_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) K->ne[1], // seq_len_kv + stride_mask3, // stride_mask3 + blk_nblk0, // nblk0 + blk_nblk1, // nblk1 + }; + blk_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(mask), + .offset = ggml_webgpu_tensor_align_offset(ctx, mask), + .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, + { .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes }, + }; + scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); + } + + std::vector split_params = params; + if (use_blk) { + split_params.push_back(0u); // blk_base + split_params.push_back(blk_nblk0); // blk_nblk0 + split_params.push_back(blk_nblk1); // blk_nblk1 + } + split_params.push_back(0u); // tmp_data_base + split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base + split_params.push_back(nwg); // nwg + + std::vector split_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(Q), + .offset = ggml_webgpu_tensor_align_offset(ctx, Q), + .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(K), + .offset = ggml_webgpu_tensor_align_offset(ctx, K), + .size = ggml_webgpu_tensor_binding_size(ctx, K) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(V), + .offset = ggml_webgpu_tensor_align_offset(ctx, V), + .size = ggml_webgpu_tensor_binding_size(ctx, V) }, + }; + uint32_t split_binding_index = 3; + if (has_mask) { + split_entries.push_back({ .binding = split_binding_index++, + .buffer = ggml_webgpu_tensor_buf(mask), + .offset = ggml_webgpu_tensor_align_offset(ctx, mask), + .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + } + if (has_sinks) { + split_entries.push_back({ .binding = split_binding_index++, + .buffer = ggml_webgpu_tensor_buf(sinks), + .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), + .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); + } + if (use_blk) { + split_entries.push_back( + { .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes }); + } + split_entries.push_back( + { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size }); + split_entries.push_back({ .binding = split_binding_index++, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + webgpu_pipeline reduce_pipeline; + std::vector reduce_params; + std::vector reduce_entries; + if (use_vec_reduce) { + const uint32_t reduce_wg_size = std::max( + 32u, + std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { + .key = + { + .head_dim_v = (uint32_t) V->ne[0], + .wg_size = reduce_wg_size, + }, + .max_wg_size = reduce_wg_size, + }; + reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); + + reduce_params = { + (uint32_t) nrows, // nrows + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) Q->ne[2], // n_heads + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst + nwg, // nwg + 0u, // tmp_data_base + (uint32_t) tmp_stats_base, // tmp_stats_base + }; + + reduce_entries = { + { .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + }; + } + + const uint64_t split_wg_total = (uint64_t) wg_x * nwg; + GGML_ASSERT(split_wg_total <= UINT32_MAX); + std::vector pipelines; + std::vector> params_list; + std::vector> entries_list; + std::vector> workgroups_list; + + if (use_blk) { + pipelines.push_back(blk_pipeline); + params_list.push_back(std::move(blk_params)); + entries_list.push_back(std::move(blk_entries)); + workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count }); + } + pipelines.push_back(pipeline); + params_list.push_back(std::move(split_params)); + entries_list.push_back(std::move(split_entries)); + workgroups_list.push_back({ (uint32_t) split_wg_total, 1u }); + if (use_vec_reduce) { + pipelines.push_back(reduce_pipeline); + params_list.push_back(std::move(reduce_params)); + entries_list.push_back(std::move(reduce_entries)); + workgroups_list.push_back({ (uint32_t) nrows, 1u }); + } + + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, + entries_list, workgroups_list); + } + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } -#endif static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; @@ -2559,7 +2775,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str std::vector subs; uint32_t num_batched_kernels = 0; bool contains_set_rows = false; - for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; @@ -2834,6 +3049,86 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const ggml_tensor * Q = tensor->src[0]; + const ggml_tensor * K = tensor->src[1]; + const ggml_tensor * V = tensor->src[2]; + const ggml_tensor * mask = tensor->src[3]; + const ggml_tensor * sinks = tensor->src[4]; + if (Q && K && V) { + GGML_UNUSED(sinks); + const bool kv_direct = (K->type == GGML_TYPE_F16) && + (Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool use_vec = + (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && + (V->type == K->type); + if (use_vec) { + const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; + const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; + const size_t limit_bytes = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const size_t q_tile = sg_mat_m; + const size_t base_q_bytes = + (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!kv_direct) { + bytes_per_kv += std::max(Q->ne[0], V->ne[0]); + } + if (mask != nullptr) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + uint32_t kv_tile = + ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; + kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); + kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; + if (kv_direct) { + GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= sg_mat_n; + } + } + + const uint32_t vec_nwg_cap = std::max( + 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + + const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + if (nwg > 1u) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + const size_t tmp_size_bytes = ROUNDUP_POW2( + (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += tmp_size_bytes + align; + } + if (mask != nullptr) { + const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); + const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); + const uint32_t stride_mask3 = + (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + const size_t blk_size_bytes = + ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += blk_size_bytes + align; + } + res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + } + break; default: break; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl new file mode 100644 index 00000000000..82d072be73a --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -0,0 +1,105 @@ +diagnostic(off, subgroup_uniformity); +enable f16; + +#define Q_TILE 1 +#define KV_TILE 32 +#define WG_SIZE 32 + +struct Params { + offset_mask: u32, + seq_len_q: u32, + seq_len_kv: u32, + stride_mask3: u32, + // Number of KV blocks and Q blocks per batch. + // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE). + nblk0: u32, + nblk1: u32, +}; + +@group(0) @binding(0) var mask: array; +@group(0) @binding(1) var blk: array; +@group(0) @binding(2) var params: Params; + +const MASK_MIN: f32 = -65504.0; +const MASK_MAX: f32 = 65504.0; +var wg_min: array; +var wg_max: array; +var wg_any: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3) { + // Dispatch mapping: + // - x indexes KV blocks + // - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk + let kv_blk = wg_id.x; + let y = wg_id.y; + let q_blk = y % params.nblk1; + let batch_idx = y / params.nblk1; + if (kv_blk >= params.nblk0) { + return; + } + + let q_start = q_blk * Q_TILE; + let k_start = kv_blk * KV_TILE; + + let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u); + let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3; + + // We keep min/max to classify: + // - fully masked (max <= MASK_MIN) + // - all-zero mask (min == 0 && max == 0) + // - mixed/general mask + var local_min = MASK_MAX; + var local_max = -MASK_MAX; + var local_any = 0u; + + for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) { + let q_row = q_start + q_rel; + if (q_row >= params.seq_len_q) { + continue; + } + let row_base = mask_batch_base + q_row * params.seq_len_kv; + for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) { + let k_col = k_start + k_rel; + if (k_col >= params.seq_len_kv) { + continue; + } + let mv = f32(mask[row_base + k_col]); + local_min = min(local_min, mv); + local_max = max(local_max, mv); + local_any = 1u; + } + } + + wg_min[local_id.x] = local_min; + wg_max[local_id.x] = local_max; + wg_any[local_id.x] = local_any; + workgroupBarrier(); + + // Thread 0 writes one state per block. + if (local_id.x == 0u) { + var mmin = wg_min[0]; + var mmax = wg_max[0]; + var many = wg_any[0]; + for (var i = 1u; i < WG_SIZE; i += 1u) { + mmin = min(mmin, wg_min[i]); + mmax = max(mmax, wg_max[i]); + many = max(many, wg_any[i]); + } + + var state = 0u; + if (many != 0u) { + if (mmax <= MASK_MIN) { + state = 0u; + } else if (mmin == 0.0 && mmax == 0.0) { + state = 2u; + } else { + state = 1u; + } + } + + let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk; + blk[blk_idx] = state; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl new file mode 100644 index 00000000000..9a0de82a56a --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -0,0 +1,78 @@ +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; + +// Default values +#define HEAD_DIM_V 64 +#define WG_SIZE 128 + +struct Params { + nrows: u32, + seq_len_q: u32, + n_heads: u32, + offset_dst: u32, + nwg: u32, + tmp_data_base: u32, + tmp_stats_base: u32, +}; + +@group(0) @binding(0) var tmp: array; +@group(0) @binding(1) var dst: array>; +@group(0) @binding(2) var params: Params; + +const FLOAT_MIN: f32 = -1.0e9; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + let rid = wg_id.x; + if (rid >= params.nrows) { + return; + } + + let rows_per_batch = params.n_heads * params.seq_len_q; + let batch_idx = rid / rows_per_batch; + let rem = rid % rows_per_batch; + let head_idx = rem / params.seq_len_q; + let q_row = rem % params.seq_len_q; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V; + + let thread = sg_inv_id; + if (params.nwg > subgroup_size) { + return; + } + + let stats_base = params.tmp_stats_base + rid * (2u * params.nwg); + let active_thread = thread < params.nwg; + let si = select(0.0, tmp[stats_base + 2u * thread + 0u], active_thread); + let mi = select(FLOAT_MIN, tmp[stats_base + 2u * thread + 1u], active_thread); + let m = subgroupMax(mi); + let ms = select(0.0, exp(mi - m), active_thread); + let s = subgroupAdd(si * ms); + let inv_s = select(0.0, 1.0 / s, s != 0.0); + + let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg); + for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) { + var weighted = vec4(0.0, 0.0, 0.0, 0.0); + if (active_thread) { + let src = row_tmp_base + thread * HEAD_DIM_V + elem_base; + weighted = vec4(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms; + } + + let sum_x = subgroupAdd(weighted.x); + let sum_y = subgroupAdd(weighted.y); + let sum_z = subgroupAdd(weighted.z); + let sum_w = subgroupAdd(weighted.w); + + if (thread == 0u) { + let dst_vec_index = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = vec4(sum_x, sum_y, sum_z, sum_w) * inv_s; + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl new file mode 100644 index 00000000000..a52575871ae --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -0,0 +1,729 @@ +diagnostic(off, chromium.subgroup_matrix_uniformity); +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +#ifdef KV_F32 +#define KV_TYPE f32 +#else +#define KV_TYPE f16 +#endif + +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 + + +#define SG_MAT_M 8 +#define SG_MAT_N 8 +#define SG_MAT_K 8 + +#define Q_TILE SG_MAT_M +#define KV_TILE 16 +#define WG_SIZE 64 +#ifndef VEC_NE +#define VEC_NE 4u +#endif + +#define KV_BLOCKS (KV_TILE / SG_MAT_N) + +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) +#if defined(KV_Q4_0) +#define NQ 16 +#define F16_PER_BLOCK 9 +#define WEIGHTS_PER_F16 4 +#elif defined(KV_Q8_0) +#define NQ 8 +#define F16_PER_BLOCK 17 +#define WEIGHTS_PER_F16 2 +#endif +#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) + +fn get_byte(value: u32, index: u32) -> u32 { + return (value >> (index * 8)) & 0xFF; +} + +fn get_byte_i32(value: u32, index: u32) -> i32 { + return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; +} + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + // shapes of Q/K/V + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + // strides (in elements) + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA + q_per_kv: u32, + + // softmax params + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, + +#ifdef BLK + blk_base: u32, + blk_nblk0: u32, + blk_nblk1: u32, +#endif + + tmp_data_base: u32, + tmp_stats_base: u32, + nwg: u32, +}; + +@group(0) @binding(0) var Q: array; +#if defined(KV_Q4_0) || defined(KV_Q8_0) +@group(0) @binding(1) var K: array; +#else +@group(0) @binding(1) var K: array>; +#endif +#if defined(KV_Q4_0) || defined(KV_Q8_0) +@group(0) @binding(2) var V: array; +#else +@group(0) @binding(2) var V: array>; +#endif +#if defined(MASK) && defined(SINKS) +@group(0) @binding(3) var mask: array; +@group(0) @binding(4) var sinks: array; +#ifdef BLK +#define BLK_BINDING 5 +#define TMP_BINDING 6 +#define DST_BINDING 7 +#define PARAMS_BINDING 8 +#else +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#endif +#elif defined(MASK) +@group(0) @binding(3) var mask: array; +#ifdef BLK +#define BLK_BINDING 4 +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#else +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#elif defined(SINKS) +@group(0) @binding(3) var sinks: array; +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#else +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif + +#ifdef BLK +@group(0) @binding(BLK_BINDING) var blk: array; +#endif +@group(0) @binding(TMP_BINDING) var tmp: array; +@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(PARAMS_BINDING) var params: Params; + +// Just a very small float value. +const FLOAT_MIN: f32 = -1.0e9; + +var q_shmem: array; + +#ifndef KV_DIRECT +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); +// we can reuse the same shmem for K and V since we only need one at a time +var kv_shmem: array; +#endif + +var o_shmem: array; + +#ifdef MASK +// storage for mask values +var mask_shmem: array; +#endif + +// note that we reuse the same storage for both since we only need one at a time +var inter_shmem: array; + +// Storage for row max and exp sum during online softmax +var row_max_shmem: array; +var exp_sum_shmem: array; +var blk_state_wg: u32; + +fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { + var v = select(FLOAT_MIN, + f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, + kv_idx < KV_TILE); +#ifdef LOGIT_SOFTCAP + v = params.logit_softcap * tanh(v); +#endif +#ifdef MASK + if (apply_mask) { + var mask_val = select(0.0,f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE); + v += select(mask_val, slope * mask_val, has_bias); + } +#endif + return v; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + + // initialize row max for online softmax + for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { + row_max_shmem[i] = FLOAT_MIN; + exp_sum_shmem[i] = 0.0; + } + + for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) { + o_shmem[i] = 0.0; + } + + // workgroups per head/batch + let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + let iwg = wg_id.x % params.nwg; + let base_wg_id = wg_id.x / params.nwg; + + // batch index + let batch_idx = base_wg_id / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let wg_in_batch = base_wg_id % wg_per_batch; + + // head index + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_idx = k_head_idx; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; + + // starting Q row for this workgroup + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head * Q_TILE; + +#ifdef MASK + // mask offset + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + let head = f32(head_idx); + let has_bias = params.max_bias > 0.0; + let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias); + + // load q tile into shared memory + for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let q_row = elem_idx / HEAD_DIM_QK; + let q_col = elem_idx % HEAD_DIM_QK; + let head_q_row = q_row_start + q_row; + let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + q_shmem[elem_idx] = f16(select( + 0.0, + Q[global_q_row_offset + q_col], + head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); + } + + for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { +#ifdef BLK + let q_blk = q_row_start / Q_TILE; + let kv_blk = kv_tile / KV_TILE; + let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u); + let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk; + let blk_state_local = blk[blk_idx]; +#else + let blk_state_local = 1u; +#endif + if (local_id.x == 0u) { + blk_state_wg = blk_state_local; + } + workgroupBarrier(); + let blk_state = blk_state_wg; + let skip_tile = blk_state == 0u; + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + inter_shmem[elem_idx] = f16(0.0); + } + + // load k tile into shared memory +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#elif defined(KV_DIRECT) + // Direct global loads for KV +#else + for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; + let vec_idx = (global_k_row_offset + k_col) >> 2u; + let k4 = select(vec4(0.0), K[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f16(k4.x); + kv_shmem[elem_idx + 1u] = f16(k4.y); + kv_shmem[elem_idx + 2u] = f16(k4.z); + kv_shmem[elem_idx + 3u] = f16(k4.w); + } +#endif + + workgroupBarrier(); + + // accumulate q block * k block into registers across the entire KV tile + if (!skip_tile) { + let num_of_threads = subgroup_size / VEC_NE; + let tx = sg_inv_id % num_of_threads; + let ty = sg_inv_id / num_of_threads; + for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + continue; + } + let local_q_row_offset = q_tile_row * HEAD_DIM_QK; + + for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) { + let kv_idx = kv_base + ty; + var partial_sum: f32 = 0.0; + let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; + if (kv_valid) { + for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { + let q_off = local_q_row_offset + i * 4u; + + let qv = vec4( + f32(q_shmem[q_off + 0u]), + f32(q_shmem[q_off + 1u]), + f32(q_shmem[q_off + 2u]), + f32(q_shmem[q_off + 3u])); +#ifdef KV_DIRECT + let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); + let kv = vec4(K[idx >> 2u]); +#else + let idx = kv_idx * HEAD_DIM_QK + (i * 4u); + let kv = vec4( + f32(kv_shmem[idx + 0u]), + f32(kv_shmem[idx + 1u]), + f32(kv_shmem[idx + 2u]), + f32(kv_shmem[idx + 3u])); +#endif + partial_sum += dot(qv, kv); + } + } + var sum = partial_sum; + // Reduce over tx threads (NL) for this ty stripe. + var tx_delta = num_of_threads >> 1u; + loop { + if (tx_delta == 0u) { + break; + } + let sh = subgroupShuffleDown(sum, tx_delta); + if (tx < tx_delta) { + sum += sh; + } + tx_delta >>= 1u; + } + + let sum_bcast = subgroupShuffle(sum, num_of_threads * ty); + if (tx == 0u && kv_valid) { + let dst_idx = q_tile_row * KV_TILE + kv_idx; + inter_shmem[dst_idx] = f16(sum_bcast); + } + } + } + } + + +#ifdef MASK + let apply_mask = !skip_tile && (blk_state != 2u); + if (apply_mask) { + // load mask tile into shared memory for this KV block + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + let mask_row = elem_idx / KV_TILE; + let mask_col = elem_idx % KV_TILE; + let global_q_row = q_row_start + mask_row; + let global_k_col = kv_tile + mask_col; + let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; + let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col; + mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); + } + } +#else + let apply_mask = false; +#endif + + workgroupBarrier(); + + // online softmax + if (!skip_tile) { + for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + var prev_max = row_max_shmem[q_tile_row]; + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE; + let softmax_term = select(FLOAT_MIN, + calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask), + kv_valid); + final_max = subgroupMax(max(final_max, softmax_term)); + } + + var total_exp_term: f32 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p); + } + } + + let cur_exp = exp(prev_max - final_max); + + if (sg_inv_id == 0) { + row_max_shmem[q_tile_row] = final_max; + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp); + } + } + } + + // load v tile into shared memory +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#elif defined(KV_DIRECT) + // Direct global loads for KV +#else + for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; + let vec_idx = (global_v_row_offset + v_col) >> 2u; + let v4 = select(vec4(0.0), V[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f16(v4.x); + kv_shmem[elem_idx + 1u] = f16(v4.y); + kv_shmem[elem_idx + 2u] = f16(v4.z); + kv_shmem[elem_idx + 3u] = f16(v4.w); + } +#endif + + workgroupBarrier(); + + if (!skip_tile) { + // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem + // we want to compute O += P * V across the full KV tile + let ne_threads : u32 = VEC_NE; + let nl_threads = max(1u, subgroup_size / ne_threads); + let tx_pv = sg_inv_id % nl_threads; + let ty_pv = sg_inv_id / nl_threads; + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) { + var lo = vec4(0.0, 0.0, 0.0, 0.0); + for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) { + let kv_idx = cc * ne_threads + ty_pv; + let v_row = kv_tile + kv_idx; + if (v_row >= params.seq_len_kv) { + continue; + } + + let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); +#ifdef KV_DIRECT + let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; + let v4 = vec4(V[v_idx >> 2u]); +#else + let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u; + let v4 = vec4( + f32(kv_shmem[v_idx + 0u]), + f32(kv_shmem[v_idx + 1u]), + f32(kv_shmem[v_idx + 2u]), + f32(kv_shmem[v_idx + 3u])); +#endif + lo += p * v4; + } + + var lo_x = lo.x; + var lo_y = lo.y; + var lo_z = lo.z; + var lo_w = lo.w; + // Reduce over ty threads (NE) for this tx thread. + var ty_delta = ne_threads >> 1u; + loop { + if (ty_delta == 0u) { + break; + } + let thread_delta = ty_delta * nl_threads; + let shx = subgroupShuffleDown(lo_x, thread_delta); + let shy = subgroupShuffleDown(lo_y, thread_delta); + let shz = subgroupShuffleDown(lo_z, thread_delta); + let shw = subgroupShuffleDown(lo_w, thread_delta); + if (ty_pv < ty_delta) { + lo_x += shx; + lo_y += shy; + lo_z += shz; + lo_w += shw; + } + ty_delta >>= 1u; + } + + if (ty_pv == 0u) { + let elem_base = vec_col * 4u; + let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base; + o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x); + o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y); + o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z); + o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w); + } + } + } + } + + workgroupBarrier(); + } + + +#ifdef SINKS + // Sinks are global terms and must be applied exactly once across split workgroups. + if (iwg == 0u) { + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + var prev_max = row_max_shmem[q_tile_row]; + + // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum + let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); + let new_max = subgroupMax(max(prev_max, sink_val)); + let max_exp = exp(prev_max - new_max); + let sink_exp = exp(sink_val - new_max); + + let sink_exp_sum = subgroupAdd(sink_exp); + + if (sg_inv_id == 0) { + row_max_shmem[q_tile_row] = new_max; + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp); + } + } + workgroupBarrier(); + } +#endif + let rows_per_batch = params.n_heads * params.seq_len_q; + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { break; } + + if (params.nwg == 1u) { + let exp_sum = exp_sum_shmem[q_tile_row]; + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + let row_base: u32 = + params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V; + + for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { + let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); + let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); + let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); + let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); + + let v = vec4( + f32(o_shmem[i0]) * scale, + f32(o_shmem[i1]) * scale, + f32(o_shmem[i2]) * scale, + f32(o_shmem[i3]) * scale + ); + + let dst_vec_index: u32 = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = v; + } + } else { + let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row; + let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V; + let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg; + + for (var elem_base = sg_inv_id * 4u; + elem_base < HEAD_DIM_V; + elem_base += subgroup_size * 4u) { + + let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); + let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); + let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); + let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); + + let tbase = tmp_row_data_base + elem_base; + tmp[tbase + 0u] = f32(o_shmem[i0]); + tmp[tbase + 1u] = f32(o_shmem[i1]); + tmp[tbase + 2u] = f32(o_shmem[i2]); + tmp[tbase + 3u] = f32(o_shmem[i3]); + } + + if (sg_inv_id == 0u) { + tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row]; + tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row]; + } + } + } +} From 321f62823902b890bc1eb5594f937e853c6afc3b Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Fri, 3 Apr 2026 10:28:09 +0300 Subject: [PATCH 085/249] rpc : reuse compute graph buffers (llama/21299) Reuse the buffer for the ggml context which is used for creating the compute graph on the server side. This partially addresses a memory leak created by the CUDA backend due to using buffer addresses as cache keys. ref: #21265 ref: #20315 --- ggml/src/ggml-rpc/ggml-rpc.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 1378ba9f5bf..4e2f1ab0f23 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1009,8 +1009,8 @@ class rpc_server { bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response); struct stored_graph { - ggml_context_ptr ctx_ptr; - ggml_cgraph * graph; + std::vector buffer; + ggml_cgraph * graph; }; private: @@ -1518,10 +1518,12 @@ bool rpc_server::graph_compute(const std::vector & input) { LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors); size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); - + if (stored_graphs[device].buffer.size() < buf_size) { + stored_graphs[device].buffer.resize(buf_size); + } struct ggml_init_params params = { /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ NULL, + /*.mem_buffer =*/ stored_graphs[device].buffer.data(), /*.no_alloc =*/ true, }; ggml_context_ptr ctx_ptr { ggml_init(params) }; @@ -1551,7 +1553,6 @@ bool rpc_server::graph_compute(const std::vector & input) { } ggml_status status = ggml_backend_graph_compute(backends[device], graph); GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); - stored_graphs[device].ctx_ptr.swap(ctx_ptr); stored_graphs[device].graph = graph; return true; } From 3f5117610b9053b1a7a7f9db66181645063ce4cf Mon Sep 17 00:00:00 2001 From: Vishal Singh Date: Fri, 3 Apr 2026 14:49:08 +0530 Subject: [PATCH 086/249] ggml-zendnn : add MUL_MAT_ID op support for MoE models (llama/21315) * ggml-zendnn : add MUL_MAT_ID op support for MoE models - Add MUL_MAT_ID op acceleration for Mixture-of-Experts models - MUL_MAT_ID op fallback to CPU backend if total experts > 32 - Point ZenDNN lib to latest bits ZenDNN-2026-WW13 * ggml-zendnn : add braces to sgemm failure condition for consistency Co-authored-by: Aaron Teo --------- Co-authored-by: Aaron Teo --- ggml/src/ggml-zendnn/CMakeLists.txt | 2 +- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 179 +++++++++++++++++++++++++++ 2 files changed, 180 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index 9bdb4e836d3..4f321a25257 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf # ZenDNN-2026-WW08 + GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13 PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index c8760304008..377303720c7 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -190,6 +190,170 @@ static void ggml_zendnn_compute_forward_mul_mat( } } +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +static void ggml_zendnn_compute_forward_mul_mat_id( + ggml_backend_zendnn_context * ctx, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; // expert weights + const ggml_tensor * src1 = dst->src[1]; // inputs + const ggml_tensor * ids = dst->src[2]; // expert ids + + GGML_TENSOR_BINARY_OP_LOCALS + + // exit for no tokens to process + if (ne2 == 0 || ne11 == 0) { + return; + } + + ggml_type const vec_dot_type = src0->type; + ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(ne3 == 1); + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_experts + + std::vector matrix_row_counts(n_as, 0); + std::vector> matrix_rows(n_as); + + int64_t max_rows = 0; + // group rows by expert (preprocessing step) + for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int id = 0; id < n_ids; ++id) { + const int32_t i02 = *(const int32_t *)((const char *)ids->data + iid1*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + matrix_rows[i02].push_back({id, iid1}); + matrix_row_counts[i02]++; + if (matrix_row_counts[i02] > max_rows) { + max_rows = matrix_row_counts[i02]; + } + } + } + + if (max_rows == 0) { + return; // no rows to process + } + + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + // size for converting src1 rows to vec_dot_type if needed + const size_t nbw1 = row_size; + const size_t nbw2 = nbw1 * ne11; + const size_t nbw3 = nbw2 * ne12; + const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0; + + // size for MoE gather/scatter buffers + const size_t wdata_cur_size = max_rows * row_size; + const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01); + + // allocate single buffer for all needs + const size_t total_size = src1_conv_size + wdata_cur_size + dst_cur_size; + if (ctx->work_size < total_size) { + ctx->work_data.reset(new char[total_size]); + ctx->work_size = total_size; + } + + // partition the buffer + char * work_data = ctx->work_data.get(); + char * wdata_cur = work_data + src1_conv_size; + char * dst_cur = wdata_cur + wdata_cur_size; + + if (src1->type != vec_dot_type) { + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static) + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13); + void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3; + from_float(src1_f32, src1_conv, ne10); + } + } + } + } + + const void * wdata = src1->type == vec_dot_type ? src1->data : work_data; + + // process each expert with gather -> gemm -> scatter pattern + for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const char * src0_cur = (const char *) src0->data + cur_a*nb02; + + // gather input rows for this expert + #pragma omp parallel for num_threads(ctx->n_threads) schedule(static) + for (int64_t ir1 = 0; ir1 < cne1; ++ir1) { + const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1]; + const int64_t id = row_mapping.i1; + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; + + std::memcpy( + wdata_cur + ir1 * row_size, + (const char *) wdata + (i11 + i12*ne11) * row_size, + row_size + ); + } + + // batched gemm for all tokens in this expert + if (!ggml_zendnn_sgemm(ctx, + ne01, // m + cne1, // n + ne10, // k + src0_cur, + ne00, // lda + wdata_cur, + ne10, // ldb + dst_cur, + ne01, // ldc + src0->type, + vec_dot_type, + dst->type)) { + GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); + } + + // scatter output rows to destination + #pragma omp parallel for num_threads(ctx->n_threads) schedule(static) + for (int64_t ir1 = 0; ir1 < cne1; ++ir1) { + const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1]; + const int64_t id = row_mapping.i1; + const int64_t i1 = id; + const int64_t i2 = row_mapping.i2; + + std::memcpy( + (char *) dst->data + i1*nb1 + i2*nb2, + dst_cur + ir1 * ggml_row_size(dst->type, ne01), + ggml_row_size(dst->type, ne01) + ); + } + } +} + // backend interface static const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) { @@ -218,6 +382,9 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm case GGML_OP_MUL_MAT: ggml_zendnn_compute_forward_mul_mat(ctx, node); break; + case GGML_OP_MUL_MAT_ID: + ggml_zendnn_compute_forward_mul_mat_id(ctx, node); + break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -361,6 +528,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const return true; case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: { const ggml_tensor * weights = op->src[0]; const ggml_tensor * inputs = op->src[1]; @@ -374,6 +542,17 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) { return false; } + // MUL_MAT_ID performs best with a moderate number of experts due to its + // gather + batched matmul + scatter approach. Future versions will leverage + // ZenDNN's grouped_gemm for better scalability with larger expert counts: + // https://github.com/amd/ZenDNN/blob/main/docs/operator/lowoha_group_gemm_operator.md + if (op->op == GGML_OP_MUL_MAT_ID) { + const int64_t n_experts = weights->ne[2]; + const int64_t max_experts = 32; + if (n_experts > max_experts) { + return false; + } + } switch (weights->type) { case GGML_TYPE_F32: case GGML_TYPE_BF16: From d6cfdc669cad5faa0171011b57df3ed7c1ed4911 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 3 Apr 2026 11:40:14 -0700 Subject: [PATCH 087/249] ggml-webgpu: move from parameter buffer pool to single buffer with offsets (llama/21278) * Work towards removing bitcast * Move rest of existing types over * Add timeout back to wait and remove synchronous set_tensor/memset_tensor * move to unpackf16 for wider compatibility * cleanup * Remove deadlock condition in free_bufs * Start work on removing parameter buffer pools * Simplify and optimize further * simplify profile futures * Fix stride * Try using a single command buffer per batch * formatting --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 43 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 758 ++++++++---------- 2 files changed, 379 insertions(+), 422 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 1c56c689312..669d2cd53a8 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -437,12 +437,18 @@ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_ // Head-dim specializations used by the tuned vec f16 path. switch (key.head_dim_qk) { - case 64: return 2u; - case 96: return 4u; - case 128: return 1u; - case 192: return 2u; - case 576: return 2u; - default: return 1u; + case 64: + return 2u; + case 96: + return 4u; + case 128: + return 1u; + case 192: + return 2u; + case 576: + return 2u; + default: + return 1u; } } @@ -513,9 +519,9 @@ struct ggml_webgpu_flash_attn_blk_shader_lib_context { }; inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { std::vector defines; std::string variant = "flash_attn_vec_blk"; @@ -1857,9 +1863,8 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = - std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); if (context.key.use_vec) { q_tile = 1; kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context))); @@ -1885,14 +1890,14 @@ class ggml_webgpu_shader_lib { } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; + const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - pipeline.context = decisions; + auto decisions = std::make_shared(); + decisions->q_tile = q_tile; + decisions->kv_tile = kv_tile; + decisions->wg_size = wg_size; + pipeline.context = decisions; flash_attn_pipelines[context.key] = pipeline; return flash_attn_pipelines[context.key]; } @@ -1905,7 +1910,7 @@ class ggml_webgpu_shader_lib { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); flash_attn_blk_pipelines[context.key] = pipeline; return flash_attn_blk_pipelines[context.key]; } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e53281bfbbd..5c567dc0df0 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -81,12 +81,10 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim /* Constants */ -#define WEBGPU_NUM_PARAM_BUFS 96u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u +#define WEBGPU_NUM_PARAM_SLOTS \ + (WEBGPU_COMMAND_SUBMIT_BATCH_SIZE + 10) // a few extra for safety, since some operations may need multiple slots #define WEBGPU_WAIT_ANY_TIMEOUT_MS 100 -// Maximum number of in-flight submissions per-thread, to avoid exhausting the -// parameter buffer pool -#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 @@ -122,87 +120,45 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::BufferUsage usage, const char * label); -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_buf_pool { - std::vector free; - - // The pool must be synchronized because - // 1. The memset pool is shared globally by every ggml buffer, - // since allocating a pool per ggml buffer would consume too much memory. - // 2. For the per-thread buffer pools in webgpu_context, - // buffers are allocated and freed in Dawn callbacks, - // which can run on a different thread than the calling thread. - std::mutex mutex; - std::condition_variable cv; - size_t cur_pool_size; - size_t max_pool_size; - wgpu::Device device; - wgpu::BufferUsage dev_buf_usage; - size_t buf_size; - bool should_grow; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - bool should_grow = false, - size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) { - this->max_pool_size = max_pool_size; - this->cur_pool_size = num_bufs; - this->device = device; - this->dev_buf_usage = dev_buf_usage; - this->buf_size = buf_size; - this->should_grow = should_grow; - for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - free.push_back(dev_buf); +// Slot-based parameter arena for compute graph encoding. Each encoded kernel +// gets a unique uniform-buffer slice within the current batch, and the slot +// cursor is reset immediately after that batch is submitted. +struct webgpu_param_arena { + wgpu::Buffer buffer; + size_t slot_stride = 0; + size_t slot_size = 0; + uint32_t slot_count = 0; + uint32_t next_slot = 0; + + void init(wgpu::Device device, size_t slot_size, uint32_t slot_count, size_t alignment) { + this->slot_stride = ROUNDUP_POW2(slot_size, alignment); + this->slot_size = slot_size; + this->slot_count = slot_count; + this->next_slot = 0; + + ggml_webgpu_create_buffer(device, buffer, this->slot_stride * slot_count, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, "ggml_webgpu_param_arena"); + } + + size_t alloc_slot(size_t size) { + GGML_ASSERT(size <= slot_size); + if (next_slot >= slot_count) { + GGML_ABORT("ggml_webgpu: parameter arena exhausted while encoding a batch"); } - } - wgpu::Buffer alloc_bufs() { - std::unique_lock lock(mutex); - if (!free.empty()) { - wgpu::Buffer buf = free.back(); - free.pop_back(); - return buf; - } - - // Try growing the pool if no free buffers - if (free.empty() && cur_pool_size < max_pool_size && should_grow) { - cur_pool_size++; - lock.unlock(); // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - - if (!dev_buf) { - GGML_ABORT("webgpu_buf_pool: failed to allocate buffers"); - } - return dev_buf; - } - cv.wait(lock, [this] { return !free.empty(); }); - wgpu::Buffer buf = free.back(); - free.pop_back(); - return buf; + return slot_stride * next_slot++; } - void free_bufs(std::vector bufs) { - std::lock_guard lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } + void reset() { next_slot = 0; } void cleanup() { - std::lock_guard lock(mutex); - for (auto & buf : free) { - if (buf) { - buf.Destroy(); - } + if (buffer) { + buffer.Destroy(); + buffer = nullptr; } - free.clear(); } - ~webgpu_buf_pool() { this->cleanup(); } + ~webgpu_param_arena() { this->cleanup(); } }; #ifdef GGML_WEBGPU_GPU_PROFILE @@ -269,10 +225,8 @@ struct webgpu_gpu_profile_buf_pool { }; #endif -struct webgpu_command { - uint32_t num_kernels; - wgpu::CommandBuffer commands; - std::vector params_bufs; +struct webgpu_encoded_op { + uint32_t num_kernels = 0; #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs timestamp_query_bufs; std::string pipeline_name; @@ -305,8 +259,8 @@ struct webgpu_global_context_struct { // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches. std::recursive_mutex mutex; - webgpu_buf_pool memset_buf_pool; - std::map memset_pipelines; // variant or type index + wgpu::Buffer memset_params_buf; + webgpu_pipeline memset_pipeline; #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) @@ -332,6 +286,10 @@ struct webgpu_global_context_struct { this->get_tensor_staging_buf.Destroy(); this->get_tensor_staging_buf = nullptr; } + if (this->memset_params_buf) { + this->memset_params_buf.Destroy(); + this->memset_params_buf = nullptr; + } #ifdef GGML_WEBGPU_DEBUG if (this->debug_host_buf) { this->debug_host_buf.Destroy(); @@ -347,13 +305,6 @@ struct webgpu_global_context_struct { typedef std::shared_ptr webgpu_global_context; -struct webgpu_submission { - wgpu::FutureWaitInfo submit_done; -#ifdef GGML_WEBGPU_GPU_PROFILE - std::vector profile_futures; -#endif -}; - // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { // Points to global instances owned by ggml_backend_webgpu_reg_context @@ -361,9 +312,9 @@ struct webgpu_context_struct { std::unique_ptr shader_lib; - webgpu_buf_pool param_buf_pool; - wgpu::Buffer set_rows_dev_error_buf; - wgpu::Buffer set_rows_host_error_buf; + webgpu_param_arena param_arena; + wgpu::Buffer set_rows_dev_error_buf; + wgpu::Buffer set_rows_host_error_buf; size_t memset_bytes_per_thread; }; @@ -448,95 +399,34 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ -static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) { - switch (status) { - case wgpu::WaitStatus::Success: - return true; - case wgpu::WaitStatus::TimedOut: - if (allow_timeout) { - return false; - } - GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n"); - return false; - case wgpu::WaitStatus::Error: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); - return false; - default: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); - return false; - } -} - #ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_erase_completed_futures(std::vector & futures) { - futures.erase(std::remove_if(futures.begin(), futures.end(), - [](const wgpu::FutureWaitInfo & info) { return info.completed; }), - futures.end()); -} - static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx, - std::vector & futures, - bool block) { + std::vector & futures) { if (futures.empty()) { return; } - uint64_t timeout_ms = block ? UINT64_MAX : 0; - if (block) { - while (!futures.empty()) { - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { - ggml_backend_webgpu_erase_completed_futures(futures); - } - } - } else { - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { - ggml_backend_webgpu_erase_completed_futures(futures); - } + constexpr size_t max_futures_per_wait = 64; + + while (!futures.empty()) { + ctx->instance.WaitAny(std::min(max_futures_per_wait, futures.size()), futures.data(), UINT64_MAX); + futures.erase(std::remove_if(futures.begin(), futures.end(), + [](const wgpu::FutureWaitInfo & info) { return info.completed; }), + futures.end()); } } #endif -// Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, - std::vector & subs, - bool block = true) { - if (subs.empty()) { - return; - } - - bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; - while (blocking_wait) { - auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6); - if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); -#endif - subs.erase(subs.begin()); - } - blocking_wait = (block && !subs.empty()) || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; - } - - if (subs.empty()) { - return; - } - - // Poll each submit future once and remove completed submissions. - for (auto sub = subs.begin(); sub != subs.end();) { - auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0); - bool success = ggml_backend_webgpu_handle_wait_status(waitStatus, true); -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false); - if (success && sub->profile_futures.empty()) { -#else - if (success) { -#endif - sub = subs.erase(sub); - } else { - ++sub; - } - } +static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) { + ctx->instance.WaitAny( + ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, + [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", + std::string(message).c_str()); + } + }), + UINT64_MAX); } static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, @@ -570,34 +460,10 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx, - std::vector & commands, - webgpu_buf_pool & param_buf_pool) { - std::vector command_buffers; - std::vector params_bufs; - webgpu_submission submission; -#ifdef GGML_WEBGPU_GPU_PROFILE - std::vector> pipeline_name_and_ts_bufs; -#endif - - for (const auto & command : commands) { - command_buffers.push_back(command.commands); - params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end()); - } - ctx->queue.Submit(command_buffers.size(), command_buffers.data()); - - wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); - } - // Free the staged buffers - param_buf_pool.free_bufs(params_bufs); - }); - submission.submit_done = { p_f }; - #ifdef GGML_WEBGPU_GPU_PROFILE +static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx, + const std::vector & commands, + std::vector & futures) { for (const auto & command : commands) { auto label = command.pipeline_name; auto ts_bufs = command.timestamp_query_bufs; @@ -616,15 +482,15 @@ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & // We can't unmap in here due to WebGPU reentrancy limitations. ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); }); - submission.profile_futures.push_back({ f }); + futures.push_back({ f }); } -#endif - return submission; } +#endif -static webgpu_command ggml_backend_webgpu_build_multi( +static webgpu_encoded_op ggml_backend_webgpu_build_multi( webgpu_global_context & ctx, - webgpu_buf_pool & param_buf_pool, + webgpu_param_arena & param_arena, + wgpu::CommandEncoder & encoder, const std::vector & pipelines, const std::vector> & params_list, const std::vector> & bind_group_entries_list, @@ -633,16 +499,21 @@ static webgpu_command ggml_backend_webgpu_build_multi( GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); GGML_ASSERT(pipelines.size() == workgroups_list.size()); - std::vector params_bufs_list; + webgpu_encoded_op result = {}; std::vector bind_groups; + std::vector param_offsets; + result.num_kernels = pipelines.size(); for (size_t i = 0; i < pipelines.size(); i++) { - wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs(); + const size_t param_size = params_list[i].size() * sizeof(uint32_t); + const size_t param_offset = param_arena.alloc_slot(param_size); std::vector entries = bind_group_entries_list[i]; uint32_t params_binding_num = entries.size(); - entries.push_back( - { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() }); + entries.push_back({ .binding = params_binding_num, + .buffer = param_arena.buffer, + .offset = param_offset, + .size = param_arena.slot_size }); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); @@ -650,13 +521,12 @@ static webgpu_command ggml_backend_webgpu_build_multi( bind_group_desc.entries = entries.data(); bind_group_desc.label = pipelines[i].name.c_str(); bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc)); - - params_bufs_list.push_back(params_bufs); + param_offsets.push_back(param_offset); } - wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - for (size_t i = 0; i < params_bufs_list.size(); i++) { - ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); + for (size_t i = 0; i < param_offsets.size(); i++) { + ctx->queue.WriteBuffer(param_arena.buffer, param_offsets[i], params_list[i].data(), + params_list[i].size() * sizeof(uint32_t)); } #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); @@ -682,29 +552,21 @@ static webgpu_command ggml_backend_webgpu_build_multi( #ifdef GGML_WEBGPU_GPU_PROFILE encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); -#endif - - wgpu::CommandBuffer commands = encoder.Finish(); - webgpu_command result = {}; - result.commands = commands; - result.params_bufs = params_bufs_list; - result.num_kernels = pipelines.size(); -#ifdef GGML_WEBGPU_GPU_PROFILE result.timestamp_query_bufs = ts_bufs; - // TODO: handle multiple pipeline names result.pipeline_name = pipelines.front().name; #endif return result; } -static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & ctx, - webgpu_buf_pool & param_buf_pool, - webgpu_pipeline & pipeline, - std::vector params, - std::vector bind_group_entries, - uint32_t wg_x, - uint32_t wg_y = 1) { - return ggml_backend_webgpu_build_multi(ctx, param_buf_pool, +static webgpu_encoded_op ggml_backend_webgpu_build(webgpu_global_context & ctx, + webgpu_param_arena & param_arena, + wgpu::CommandEncoder & encoder, + webgpu_pipeline & pipeline, + std::vector params, + std::vector bind_group_entries, + uint32_t wg_x, + uint32_t wg_y = 1) { + return ggml_backend_webgpu_build_multi(ctx, param_arena, encoder, { pipeline }, @@ -724,10 +586,28 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); - webgpu_command command = - ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); - std::vector commands = { command }; - std::vector sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) }; + ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); + + entries.push_back( + { .binding = 1, .buffer = ctx->memset_params_buf, .offset = 0, .size = WEBGPU_PARAMS_BUF_SIZE_BYTES }); + + wgpu::BindGroupDescriptor bind_group_desc; + bind_group_desc.layout = ctx->memset_pipeline.pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = entries.size(); + bind_group_desc.entries = entries.data(); + bind_group_desc.label = ctx->memset_pipeline.name.c_str(); + wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); + + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(ctx->memset_pipeline.pipeline); + pass.SetBindGroup(0, bind_group); + pass.DispatchWorkgroups(wg_x, 1, 1); + pass.End(); + + wgpu::CommandBuffer command = encoder.Finish(); + std::vector commands = { command }; + ctx->queue.Submit(commands.size(), commands.data()); } /** End WebGPU Actions */ @@ -840,7 +720,10 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 return flags; } -static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, @@ -878,10 +761,14 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { const bool inplace = ggml_webgpu_tensor_equal(src0, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -940,10 +827,13 @@ static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; @@ -995,13 +885,14 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1056,13 +947,14 @@ static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx, const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); } -static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1112,17 +1004,18 @@ static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx, const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2]; - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); -} - -static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * src3, - ggml_tensor * src4, - ggml_tensor * src5, - ggml_tensor * dst) { + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1197,13 +1090,14 @@ static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, h, n_seqs); } -static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { // For set rows specifically, we need to check if src and idx are empty // tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { @@ -1266,7 +1160,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, 1); } // Workgroup size is a common constant @@ -1277,10 +1171,11 @@ static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_si return constants; } -static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -1332,13 +1227,14 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows; uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { // Determine if this is a mat-vec operation bool is_vec = (dst->ne[1] == 1); @@ -1477,16 +1373,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); } -static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { +#ifndef __EMSCRIPTEN__ +static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { float scale = *(float *) dst->op_params; float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); @@ -1575,9 +1473,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); - const uint32_t vec_nwg_cap = - std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); - const bool use_blk = use_vec && has_mask; + const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const bool use_blk = use_vec && has_mask; ggml_webgpu_flash_attn_pipeline_key key = { .kv_type = K->type, @@ -1656,9 +1553,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, if (use_blk) { GGML_ASSERT(has_mask); - blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); - blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); - blk_buf = ggml_webgpu_tensor_buf(dst); + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); + blk_buf = ggml_webgpu_tensor_buf(dst); const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; @@ -1729,8 +1626,10 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); } if (use_blk) { - split_entries.push_back( - { .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes }); + split_entries.push_back({ .binding = split_binding_index++, + .buffer = blk_buf, + .offset = blk_entries[1].offset, + .size = blk_size_bytes }); } split_entries.push_back( { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size }); @@ -1799,14 +1698,18 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, workgroups_list.push_back({ (uint32_t) nrows, 1u }); } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, entries_list, workgroups_list); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } +#endif // __EMSCRIPTEN__ -static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); @@ -1881,13 +1784,14 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -1983,13 +1887,14 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t dim = (uint32_t) dst->op_params[0]; @@ -2039,10 +1944,13 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { ne, @@ -2081,10 +1989,13 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { @@ -2124,14 +2035,16 @@ static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * s }; webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src)); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, + ggml_nrows(src)); } -static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2228,10 +2141,14 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2290,10 +2207,13 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2341,14 +2261,15 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2424,10 +2345,14 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst)); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, + ggml_nrows(dst)); } -static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; @@ -2449,10 +2374,13 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nelements(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool is_top_k = dst->op == GGML_OP_TOP_K; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2543,7 +2471,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr workgroups_list.push_back({ wg_x_init, wg_y_init }); if (merge_passes == 0) { - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, entries_list, workgroups_list); } @@ -2605,11 +2533,14 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr in_is_tmp = !in_is_tmp; } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, - workgroups_list); + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, + entries_list, workgroups_list); } -static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; @@ -2634,10 +2565,13 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool total_sum = dst->op == GGML_OP_SUM; std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -2666,11 +2600,13 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } // Returns the encoded command, or std::nullopt if the operation is a no-op -static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { +static std::optional ggml_webgpu_encode_node(webgpu_context ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * node) { if (ggml_is_empty(node)) { return std::nullopt; } @@ -2693,18 +2629,18 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return std::nullopt; case GGML_OP_CPY: case GGML_OP_CONT: - return ggml_webgpu_cpy(ctx, src0, node); + return ggml_webgpu_cpy(ctx, encoder, src0, node); case GGML_OP_SET: - return ggml_webgpu_set(ctx, src0, src1, node); + return ggml_webgpu_set(ctx, encoder, src0, src1, node); case GGML_OP_SET_ROWS: - return ggml_webgpu_set_rows(ctx, src0, src1, node); + return ggml_webgpu_set_rows(ctx, encoder, src0, src1, node); case GGML_OP_GET_ROWS: - return ggml_webgpu_get_rows(ctx, src0, src1, node); + return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node); case GGML_OP_MUL_MAT: - return ggml_webgpu_mul_mat(ctx, src0, src1, node); + return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node); case GGML_OP_FLASH_ATTN_EXT: #ifndef __EMSCRIPTEN__ - return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); + return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node); #else return std::nullopt; #endif @@ -2712,22 +2648,22 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - return ggml_webgpu_binary_op(ctx, src0, src1, node); + return ggml_webgpu_binary_op(ctx, encoder, src0, src1, node); case GGML_OP_CONCAT: - return ggml_webgpu_concat(ctx, src0, src1, node); + return ggml_webgpu_concat(ctx, encoder, src0, src1, node); case GGML_OP_REPEAT: - return ggml_webgpu_repeat(ctx, src0, node); + return ggml_webgpu_repeat(ctx, encoder, src0, node); case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: - return ggml_webgpu_row_norm(ctx, src0, node); + return ggml_webgpu_row_norm(ctx, encoder, src0, node); case GGML_OP_ROPE: - return ggml_webgpu_rope(ctx, src0, src1, src2, node); + return ggml_webgpu_rope(ctx, encoder, src0, src1, src2, node); case GGML_OP_GLU: - return ggml_webgpu_glu(ctx, src0, src1, node); + return ggml_webgpu_glu(ctx, encoder, src0, src1, node); case GGML_OP_SCALE: - return ggml_webgpu_scale(ctx, src0, node); + return ggml_webgpu_scale(ctx, encoder, src0, node); case GGML_OP_SOFT_MAX: - return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); + return ggml_webgpu_soft_max(ctx, encoder, src0, src1, src2, node); case GGML_OP_UNARY: case GGML_OP_CLAMP: case GGML_OP_FILL: @@ -2738,26 +2674,27 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_COS: case GGML_OP_DIAG: case GGML_OP_TRI: - return ggml_webgpu_unary_op(ctx, src0, node); + return ggml_webgpu_unary_op(ctx, encoder, src0, node); case GGML_OP_SOLVE_TRI: - return ggml_webgpu_solve_tri(ctx, src0, src1, node); + return ggml_webgpu_solve_tri(ctx, encoder, src0, src1, node); case GGML_OP_SSM_CONV: - return ggml_webgpu_ssm_conv(ctx, src0, src1, node); + return ggml_webgpu_ssm_conv(ctx, encoder, src0, src1, node); case GGML_OP_GATED_DELTA_NET: - return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); + return ggml_webgpu_gated_delta_net(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node->src[5], + node); case GGML_OP_PAD: - return ggml_webgpu_pad(ctx, src0, node); + return ggml_webgpu_pad(ctx, encoder, src0, node); case GGML_OP_ARGMAX: - return ggml_webgpu_argmax(ctx, src0, node); + return ggml_webgpu_argmax(ctx, encoder, src0, node); case GGML_OP_ARGSORT: case GGML_OP_TOP_K: // we reuse the same argsort implementation for top_k - return ggml_webgpu_argsort(ctx, src0, node); + return ggml_webgpu_argsort(ctx, encoder, src0, node); case GGML_OP_CUMSUM: - return ggml_webgpu_cumsum(ctx, src0, node); + return ggml_webgpu_cumsum(ctx, encoder, src0, node); case GGML_OP_SUM: case GGML_OP_SUM_ROWS: - return ggml_webgpu_sum_rows(ctx, src0, node); + return ggml_webgpu_sum_rows(ctx, encoder, src0, node); default: return std::nullopt; } @@ -2771,30 +2708,42 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - std::vector commands; - std::vector subs; - uint32_t num_batched_kernels = 0; - bool contains_set_rows = false; + std::vector commands; +#ifdef GGML_WEBGPU_GPU_PROFILE + std::vector profile_futures; +#endif + uint32_t num_batched_kernels = 0; + bool contains_set_rows = false; + wgpu::CommandEncoder batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } - if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + if (auto cmd = ggml_webgpu_encode_node(ctx, batch_encoder, cgraph->nodes[i])) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; } if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { - num_batched_kernels = 0; - subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); - // Process events and check for completed submissions - ctx->global_ctx->instance.ProcessEvents(); - ggml_backend_webgpu_wait(ctx->global_ctx, subs, false); + num_batched_kernels = 0; + wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); + ctx->global_ctx->queue.Submit(1, &batch_commands); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); +#endif + ctx->param_arena.reset(); commands.clear(); + batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); } } if (!commands.empty()) { - subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); + wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); + ctx->global_ctx->queue.Submit(1, &batch_commands); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); +#endif + ctx->param_arena.reset(); commands.clear(); } @@ -2805,6 +2754,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->set_rows_host_error_buf.GetSize()); wgpu::CommandBuffer set_rows_commands = encoder.Finish(); ctx->global_ctx->queue.Submit(1, &set_rows_commands); + } + + ggml_backend_webgpu_wait_queue(ctx->global_ctx); + + if (contains_set_rows) { ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, ctx->set_rows_host_error_buf.GetSize()); const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); @@ -2814,7 +2768,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->set_rows_host_error_buf.Unmap(); } - ggml_backend_webgpu_wait(ctx->global_ctx, subs); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_wait_profile_futures(ctx->global_ctx, profile_futures); +#endif WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -3063,18 +3019,16 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = - (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && - (V->type == K->type); + const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && + kv_vec_type_supported && (V->type == K->type); if (use_vec) { const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const size_t q_tile = sg_mat_m; - const size_t base_q_bytes = - (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + const size_t q_tile = sg_mat_m; + const size_t base_q_bytes = (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; size_t bytes_per_kv = 0; if (!kv_direct) { bytes_per_kv += std::max(Q->ne[0], V->ne[0]); @@ -3084,10 +3038,9 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer } bytes_per_kv += q_tile; bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - uint32_t kv_tile = - ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; - kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); - kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; + uint32_t kv_tile = ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; + kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); + kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; if (kv_direct) { GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { @@ -3097,30 +3050,30 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const uint32_t vec_nwg_cap = std::max( 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { nwg <<= 1; } nwg = std::min(nwg, vec_nwg_cap); - const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t align = + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; if (nwg > 1u) { const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; const uint64_t tmp_stats_elems = nrows * 2u * nwg; - const size_t tmp_size_bytes = ROUNDUP_POW2( + const size_t tmp_size_bytes = ROUNDUP_POW2( (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); res += tmp_size_bytes + align; } if (mask != nullptr) { - const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); - const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); - const uint32_t stride_mask3 = - (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); + const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - const size_t blk_size_bytes = + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + const size_t blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); res += blk_size_bytes + align; } @@ -3195,11 +3148,11 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->capabilities.memset_bytes_per_thread = CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); std::vector constants(2); - constants[0].key = "wg_size"; - constants[0].value = WEBGPU_MAX_WG_SIZE; - constants[1].key = "bytes_per_thread"; - constants[1].value = ctx->capabilities.memset_bytes_per_thread; - ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); + constants[0].key = "wg_size"; + constants[0].value = WEBGPU_MAX_WG_SIZE; + constants[1].key = "bytes_per_thread"; + constants[1].value = ctx->capabilities.memset_bytes_per_thread; + ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { @@ -3331,9 +3284,9 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr); ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx); - ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, ctx->webgpu_global_ctx->memset_params_buf, + WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + "memset_params_buf"); ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue(); #ifdef GGML_WEBGPU_GPU_PROFILE @@ -3357,9 +3310,8 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); - webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); + webgpu_ctx->param_arena.init(webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, WEBGPU_NUM_PARAM_SLOTS, + webgpu_ctx->global_ctx->capabilities.limits.minUniformBufferOffsetAlignment); ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf"); From c03104553133724ed8d8593cb87a6659f1399e68 Mon Sep 17 00:00:00 2001 From: Yarden Tal Date: Mon, 6 Apr 2026 04:30:25 +0300 Subject: [PATCH 088/249] hexagon: slight optimization for argosrt output init (llama/21463) --- ggml/src/ggml-hexagon/htp/argsort-ops.c | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index 170220e8f80..3ec26a4c1ac 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -164,6 +164,12 @@ static void quicksort_values_indices_desc(float * values, int32_t * indices, int if (i < right) quicksort_values_indices_desc(values, indices, i, right); } +// LUT for ramp initialization of argsort output (first 32 members) +int32_t argosrt_ramp_lut[32] __attribute__((aligned(VLEN))) = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 +}; + static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { struct htp_argsort_context * actx = (struct htp_argsort_context *)data; struct htp_ops_context * octx = actx->octx; @@ -205,8 +211,12 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { // Padded to 128 bytes. size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + size_t num_vec_ind_values = hmx_ceil_div(ne00, VLEN/(sizeof(int32_t))); float * values_buf = (float *) spad; int32_t * indices_buf = (int32_t *) (spad + values_size); + HVX_Vector * indices_buf_vec = (HVX_Vector *) (spad + values_size); + const HVX_Vector ind_init_vec = *(HVX_Vector *)argosrt_ramp_lut; + const HVX_Vector ind_diff_vec = Q6_V_vsplat_R(32); for (uint32_t r = start_row; r < end_row; r++) { uint32_t src_offset = r * nb01; @@ -218,9 +228,11 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1); hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00); - // Initialize indices - for (uint32_t j = 0; j < ne00; j++) { - indices_buf[j] = j; + // Initialize indices - Start with values 0..31, add 32 for additional vec iterations + HVX_Vector curr_ind_vec = ind_init_vec; + for (uint32_t j_vec = 0; j_vec < num_vec_ind_values; j_vec++) { + indices_buf_vec[j_vec] = curr_ind_vec; + curr_ind_vec = Q6_Vw_vadd_VwVw(curr_ind_vec, ind_diff_vec); } // Sort values and mirror swaps to indices From 42e4a28865c6909d8a5b6390a68740404005aa3f Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Mon, 6 Apr 2026 18:28:00 +0800 Subject: [PATCH 089/249] sycl : handle other FA case (llama/21377) --- ggml/src/ggml-sycl/fattn-tile.hpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ggml/src/ggml-sycl/fattn-tile.hpp b/ggml/src/ggml-sycl/fattn-tile.hpp index c4d24613a55..b4d4e0ae90e 100644 --- a/ggml/src/ggml-sycl/fattn-tile.hpp +++ b/ggml/src/ggml-sycl/fattn-tile.hpp @@ -1252,6 +1252,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm return; } + { + constexpr int cols_per_block = ncols2*2; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + GGML_ABORT("fatal error"); } From 7b19b94c5dc822a21bdb2e574ece4a7b2316c436 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Tue, 7 Apr 2026 00:04:29 +0530 Subject: [PATCH 090/249] Write an optimized flash_attn_stream_k_fixup kernel (llama/21159) * Write an optimized flash_attn_stream_k_fixup kernel Write a specialized and more optimized kernel for cases where nblocks_stream_k is multiple of ntiles_dst. Make nblocks_stream_k to multiple of ntiles_dst if nblocks_stream_k > 2 * ntiles_dst * Use the new kernel only for nblocks_stream_k_raw > 4 * ntiles_dst to make sure we have enough concurrency on GPUs * Address review comments * Address review comments * Revert variable names to original --- ggml/src/ggml-cuda/fattn-common.cuh | 178 ++++++++++++++++++++++++---- 1 file changed, 153 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index c59a4db3999..beeb5238946 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -676,9 +676,96 @@ static __global__ void flash_attn_mask_to_KV_max( template // D == head size __launch_bounds__(D, 1) -static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, - const int ne11, const int ne12, const int nbatch_fa) { +static __global__ void flash_attn_stream_k_fixup_uniform( + float * __restrict__ dst, + const float2 * __restrict__ dst_fixup, + const int ne01, const int ne02, + const int ne12, const int nblocks_stream_k, + const int gqa_ratio, + const int blocks_per_tile, + const uint3 fd_iter_j_z_ne12, + const uint3 fd_iter_j_z, + const uint3 fd_iter_j) { + constexpr int ncols = ncols1*ncols2; + + const int tile_idx = blockIdx.x; // One block per output tile. + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + // nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks. + const int b_first = tile_idx * blocks_per_tile; + const int b_last = b_first + blocks_per_tile - 1; + + const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols); + + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const uint2 dm0 = fast_div_modulo(tile_idx, fd_iter_j_z_ne12); + const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_j_z); + const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_j); + + const int sequence = dm0.x; + const int z_KV = dm1.x; + const int zt_gqa = dm2.x; + const int jt = dm2.y; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. + + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { + return; + } + + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup + float dst_val = *dst; + float max_val; + float rowsum; + { + const float2 tmp = dst_fixup[b_last*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; + } + + // Combine with all previous blocks in this tile. + for (int bidx = b_last - 1; bidx >= b_first; --bidx) { + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc]; + + const float max_val_new = fmaxf(max_val, tmp.x); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; + + max_val = max_val_new; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +// General fixup kernel for the case where the number of blocks per tile is not uniform across tiles +// (blocks_num.x not a multiple of ntiles_dst) +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_stream_k_fixup_general( + float * __restrict__ dst, + const float2 * __restrict__ dst_fixup, + const int ne01, const int ne02, + const int gqa_ratio, + const int total_work, + const uint3 fd_iter_k_j_z_ne12, + const uint3 fd_iter_k_j_z, + const uint3 fd_iter_k_j, + const uint3 fd_iter_k) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -689,27 +776,26 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - - const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; - - const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; - const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc0 = int64_t(bidx0 + 0)*total_work / gridDim.x; + const int kbc0_stop = int64_t(bidx0 + 1)*total_work / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; - const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; - const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + const bool wrote_beginning_of_tile = fastmodulo(kbc0, fd_iter_k) == 0; + const bool did_not_write_last = fastdiv(kbc0, fd_iter_k) == fastdiv(kbc0_stop, fd_iter_k) && fastmodulo(kbc0_stop, fd_iter_k) != 0; if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { return; } // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index - const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12); - const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); - const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); - const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; + const uint2 dm0 = fast_div_modulo(kbc0, fd_iter_k_j_z_ne12); + const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_k_j_z); + const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_k_j); + const uint2 dm3 = fast_div_modulo(dm2.y, fd_iter_k); + + const int sequence = dm0.x; + const int z_KV = dm1.x; + const int zt_gqa = dm2.x; + const int jt = dm3.x; const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. @@ -733,10 +819,11 @@ static __global__ void flash_attn_stream_k_fixup( // Iterate over previous blocks and compute the combined results. // All CUDA blocks that get here must have a previous block that needs a fixup. + const int tile_kbc0 = fastdiv(kbc0, fd_iter_k); int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc = int64_t(bidx)*total_work / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; @@ -762,7 +849,7 @@ static __global__ void flash_attn_stream_k_fixup( max_val = max_val_new; // If this block started in a previous tile we are done and don't need to combine additional partial results. - if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + if (fastmodulo(kbc, fd_iter_k) == 0 || fastdiv(kbc, fd_iter_k) < tile_kbc0) { break; } bidx--; @@ -976,14 +1063,28 @@ void launch_fattn( const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks; const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves); - const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst); - const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75; - blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst; + blocks_num.x = ntiles_dst; blocks_num.y = 1; blocks_num.z = 1; + if(use_stream_k) { + const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst); + // Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks (avoids fixup). + // Only do this if the occupancy loss from rounding is acceptable. + const int nblocks_stream_k_rounded = (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst; + const int max_efficiency_loss_percent = 5; + const int efficiency_loss_percent = nblocks_stream_k_rounded > 0 + ? 100 * (nblocks_stream_k_raw - nblocks_stream_k_rounded) / nblocks_stream_k_raw + : 100; + const int nblocks_stream_k = efficiency_loss_percent <= max_efficiency_loss_percent + ? nblocks_stream_k_rounded + : nblocks_stream_k_raw; + + blocks_num.x = nblocks_stream_k; + } + if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2))); } @@ -1063,13 +1164,40 @@ void launch_fattn( CUDA_CHECK(cudaGetLastError()); if (stream_k) { - if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + if ((int)blocks_num.x % ntiles_dst == 0 && (int)blocks_num.x > ntiles_dst) { + // Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile. + const int nblocks_sk = (int)blocks_num.x; + const int bpt = nblocks_sk / ntiles_dst; + + const uint3 fd0 = init_fastdiv_values(ntiles_x * ntiles_z_gqa * K->ne[2]); + const uint3 fd1 = init_fastdiv_values(ntiles_x * ntiles_z_gqa); + const uint3 fd2 = init_fastdiv_values(ntiles_x); + + const dim3 block_dim_combine(DV, 1, 1); + const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2}; + + flash_attn_stream_k_fixup_uniform + <<>> + ((float *) KQV->data, dst_tmp_meta.ptr, + Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk, + gqa_ratio, bpt, fd0, fd1, fd2); + } else if (ntiles_dst % blocks_num.x != 0) { + // General fixup for the cases where nblocks_stream_k < ntiles_dst. + const int total_work = ntiles_KV * ntiles_dst; + + const uint3 fd_k_j_z_ne12 = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa * K->ne[2]); + const uint3 fd_k_j_z = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa); + const uint3 fd_k_j = init_fastdiv_values(ntiles_KV * ntiles_x); + const uint3 fd_k = init_fastdiv_values(ntiles_KV); + const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup + flash_attn_stream_k_fixup_general <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa); + ((float *) KQV->data, dst_tmp_meta.ptr, + Q->ne[1], Q->ne[2], gqa_ratio, total_work, + fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); From 0c2fbd4703a7a64a71dc07e60a17d89dac81d57b Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Mon, 6 Apr 2026 11:55:21 -0700 Subject: [PATCH 091/249] ggml: add Q1_0 1-bit quantization support (CPU) (llama/21273) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml: add Q1_0 and Q1_0_g128 1-bit quantization support (CPU) * add generic fallback for x86 * remove Q1_0 (group size 32) * rename Q1_0_g128 => Q1_0 * fix Q1_0 LlamaFileType Enum * Fix trailing spaces; add generic fallback for othre backends * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret * fix /r/n spacing + arch-fallback --------- Co-authored-by: Sigbjørn Skjæret --- ggml/include/ggml.h | 4 +- ggml/src/ggml-common.h | 11 +++ ggml/src/ggml-cpu/arch-fallback.h | 7 ++ ggml/src/ggml-cpu/arch/arm/quants.c | 103 ++++++++++++++++++++++ ggml/src/ggml-cpu/arch/loongarch/quants.c | 1 - ggml/src/ggml-cpu/arch/powerpc/quants.c | 1 - ggml/src/ggml-cpu/arch/s390/quants.c | 1 - ggml/src/ggml-cpu/arch/wasm/quants.c | 1 - ggml/src/ggml-cpu/ggml-cpu.c | 6 ++ ggml/src/ggml-cpu/ops.cpp | 2 + ggml/src/ggml-cpu/quants.c | 49 ++++++++++ ggml/src/ggml-cpu/quants.h | 3 + ggml/src/ggml-quants.c | 75 ++++++++++++++++ ggml/src/ggml-quants.h | 3 + ggml/src/ggml.c | 10 +++ 15 files changed, 272 insertions(+), 5 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 669f66b650f..3bb2faa2c66 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -428,7 +428,8 @@ extern "C" { // GGML_TYPE_IQ4_NL_8_8 = 38, GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) - GGML_TYPE_COUNT = 41, + GGML_TYPE_Q1_0 = 41, + GGML_TYPE_COUNT = 42, }; // precision @@ -465,6 +466,7 @@ extern "C" { GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors + GGML_FTYPE_MOSTLY_Q1_0 = 27, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 92cf739e7a7..f05683b44cd 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -93,6 +93,10 @@ typedef sycl::half2 ggml_half2; // QR = QK / number of values before dequantization // QI = number of 32 bit integers before dequantization +#define QI1_0 (QK1_0 / 32) +#define QR1_0 1 + + #define QI4_0 (QK4_0 / (4 * QR4_0)) #define QR4_0 2 @@ -170,6 +174,13 @@ typedef sycl::half2 ggml_half2; #define GGML_EXTENSION __extension__ #endif // _MSC_VER +#define QK1_0 128 +typedef struct { + ggml_half d; // delta + uint8_t qs[QK1_0 / 8]; // bits / quants +} block_q1_0; +static_assert(sizeof(block_q1_0) == sizeof(ggml_half) + QK1_0 / 8, "wrong q1_0 block size/padding"); + #define QK4_0 32 typedef struct { ggml_half d; // delta diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 41da829315b..c589a213e9d 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -16,6 +16,7 @@ #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -82,6 +83,7 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 @@ -112,6 +114,7 @@ // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K @@ -160,6 +163,7 @@ #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -200,6 +204,7 @@ #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -240,6 +245,7 @@ // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -303,6 +309,7 @@ #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 82b048bb3ae..e09db59cf22 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -137,6 +137,109 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in //===================================== Dot products ================================= +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; // 128 + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + float sumf = 0.0f; + +#if defined(__ARM_NEON) + float32x4_t sumv = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); + + // Process 4 Q8_0 blocks (each has 32 elements) + for (int k = 0; k < 4; k++) { + const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k]; + const float d1 = GGML_CPU_FP16_TO_FP32(yb->d); + + // Get the 4 bytes of bits for this Q8_0 block (32 bits = 4 bytes) + // Bits are at offset k*4 bytes in x[i].qs + const uint8_t * bits = &x[i].qs[k * 4]; + + // Load 32 int8 values from y + const int8x16_t y0 = vld1q_s8(yb->qs); + const int8x16_t y1 = vld1q_s8(yb->qs + 16); + + // Byte 0-1: bits for y0[0..15] + const uint64_t expand0 = table_b2b_0[bits[0]]; + const uint64_t expand1 = table_b2b_0[bits[1]]; + // Byte 2-3: bits for y1[0..15] + const uint64_t expand2 = table_b2b_0[bits[2]]; + const uint64_t expand3 = table_b2b_0[bits[3]]; + + // Build the sign vectors by reinterpreting the table values + uint8x8_t e0 = vcreate_u8(expand0); + uint8x8_t e1 = vcreate_u8(expand1); + uint8x8_t e2 = vcreate_u8(expand2); + uint8x8_t e3 = vcreate_u8(expand3); + + // Shift right by 4 to get 0 or 1 + int8x8_t s0 = vreinterpret_s8_u8(vshr_n_u8(e0, 4)); + int8x8_t s1 = vreinterpret_s8_u8(vshr_n_u8(e1, 4)); + int8x8_t s2 = vreinterpret_s8_u8(vshr_n_u8(e2, 4)); + int8x8_t s3 = vreinterpret_s8_u8(vshr_n_u8(e3, 4)); + + // Convert 0/1 to -1/+1: sign = 2*val - 1 + int8x8_t one = vdup_n_s8(1); + s0 = vsub_s8(vadd_s8(s0, s0), one); // 2*s0 - 1 + s1 = vsub_s8(vadd_s8(s1, s1), one); + s2 = vsub_s8(vadd_s8(s2, s2), one); + s3 = vsub_s8(vadd_s8(s3, s3), one); + + // Combine into 16-element vectors + int8x16_t signs0 = vcombine_s8(s0, s1); + int8x16_t signs1 = vcombine_s8(s2, s3); + + // Multiply signs with y values and accumulate + // dot(signs, y) where signs are +1/-1 + int32x4_t p0 = ggml_vdotq_s32(vdupq_n_s32(0), signs0, y0); + int32x4_t p1 = ggml_vdotq_s32(p0, signs1, y1); + + // Scale by d1 and accumulate + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(p1), d0 * d1); + } + } + + sumf = vaddvq_f32(sumv); +#else + // Scalar fallback + for (int i = 0; i < nb; i++) { + const float d0 = GGML_FP16_TO_FP32(x[i].d); + + // Process 4 Q8_0 blocks + for (int k = 0; k < 4; k++) { + const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); + + int sumi = 0; + for (int j = 0; j < QK8_0; j++) { + const int bit_index = k * QK8_0 + j; + const int byte_index = bit_index / 8; + const int bit_offset = bit_index % 8; + + const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; + sumi += xi * y[i*4 + k].qs[j]; + } + sumf += d0 * d1 * sumi; + } + } +#endif + + *s = sumf; +} + + void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index f531e916b9e..74e0c086c6d 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -2156,4 +2156,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c index d3dfd049eaf..644c380c738 100644 --- a/ggml/src/ggml-cpu/arch/powerpc/quants.c +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -2302,4 +2302,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index 34184ed8510..500857579a7 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -1463,4 +1463,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index 74a359e6d12..648c6fcaba7 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -1218,4 +1218,3 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7486acc2b5d..2b3eb5b5ce6 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -217,6 +217,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_F16, .nrows = 1, }, + [GGML_TYPE_Q1_0] = { + .from_float = quantize_row_q1_0, + .vec_dot = ggml_vec_dot_q1_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_Q4_0] = { .from_float = quantize_row_q4_0, .vec_dot = ggml_vec_dot_q4_0_q8_0, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 765ce07f06c..0b5d6c6df88 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4829,6 +4829,7 @@ void ggml_compute_forward_get_rows( const ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5554,6 +5555,7 @@ void ggml_compute_forward_clamp( ggml_compute_forward_clamp_f16(params, dst); } break; case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 7ebbb9c6f15..f66127c2290 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -22,6 +22,10 @@ #define UNUSED GGML_UNUSED +void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q1_0_ref(x, y, k); +} + void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { quantize_row_q4_0_ref(x, y, k); } @@ -116,6 +120,51 @@ void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI //===================================== Dot products ================================= +void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_FP16_TO_FP32(x[i].d); + + float sumi = 0.0f; + + for (int k = 0; k < 4; k++) { + const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); + + int sumi_block = 0; + + for (int j = 0; j < QK8_0; j++) { + const int bit_index = k * QK8_0 + j; + const int byte_index = bit_index / 8; + const int bit_offset = bit_index % 8; + + const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; + sumi_block += xi * y[i*4 + k].qs[j]; + } + + sumi += d1 * sumi_block; + } + + sumf += d0 * sumi; + } + + *s = sumf; +} + + void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 3584aaa43e8..d4bc87a1c05 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -12,6 +12,7 @@ extern "C" { #endif // Quantization +void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -36,6 +37,7 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dot product +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -68,6 +70,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 48695a61ea3..15443aa554a 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -32,6 +32,41 @@ static inline int best_index_int8(int n, const int8_t * val, float x) { return x - val[mu-1] < val[mu] - x ? mu-1 : mu; } +// reference implementation for deterministic creation of model files +void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k) { + static const int qk = QK1_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float sum_abs = 0.0f; + for (int j = 0; j < qk; j++) { + sum_abs += fabsf(x[i*qk + j]); + } + const float d = sum_abs / qk; + + y[i].d = GGML_FP32_TO_FP16(d); + + // Clear all bits first + for (int j = 0; j < qk / 8; ++j) { + y[i].qs[j] = 0; + } + + // Just store sign of each weight directly (no normalization) + for (int j = 0; j < qk; ++j) { + const int bit_index = j; + const int byte_index = bit_index / 8; + const int bit_offset = bit_index % 8; + + if (x[i*qk + j] >= 0.0f) { + y[i].qs[byte_index] |= (1 << bit_offset); + } + } + } +} + // reference implementation for deterministic creation of model files void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -339,6 +374,26 @@ void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RE } } +void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK1_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float neg_d = -d; + + for (int j = 0; j < qk; ++j) { + const int byte_index = j / 8; + const int bit_offset = j % 8; + const uint8_t bit = (x[i].qs[byte_index] >> bit_offset) & 1; + y[i*qk + j] = bit ? d : neg_d; + } + } +} + void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -1978,6 +2033,22 @@ static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * G } } +size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q1_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q1_0, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q1_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q1_0_ref(src, (block_q1_0*)qrow, n_per_row); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + + size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row); @@ -5286,6 +5357,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte } } } break; + case GGML_TYPE_Q1_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q1_0, data, nb); + } break; case GGML_TYPE_Q4_0: { VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 00604f75c0e..d56c86da890 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -14,6 +14,7 @@ extern "C" { // NOTE: these functions are defined as GGML_API because they used by the CPU backend // Quantization +GGML_API void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); @@ -41,6 +42,7 @@ GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_ GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); // Dequantization +GGML_API void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -90,6 +92,7 @@ GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e9b6720c0af..0142498d967 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -651,6 +651,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row, }, + [GGML_TYPE_Q1_0] = { + .type_name = "q1_0", + .blck_size = QK1_0, + .type_size = sizeof(block_q1_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q1_0, + .from_float_ref = (ggml_from_float_t) quantize_row_q1_0_ref, + }, [GGML_TYPE_Q4_0] = { .type_name = "q4_0", .blck_size = QK4_0, @@ -1384,6 +1392,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; + case GGML_FTYPE_MOSTLY_Q1_0: wtype = GGML_TYPE_Q1_0; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; @@ -7652,6 +7661,7 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { + case GGML_TYPE_Q1_0: result = quantize_q1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; From 9cbc4b3acb70f1eabd916b7deacf0ba511185ee8 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Tue, 7 Apr 2026 05:08:46 +0900 Subject: [PATCH 092/249] ggml-webgpu: Add the support of `MUL_MAT_ID` (llama/21147) * Add mul_mat_id support to WebGPU * Apply suggestion from @reeselevine --------- Co-authored-by: Reese Levine --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 134 +++++++++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 202 ++++++++++++++++++ .../wgsl-shaders/mul_mat_decls.tmpl | 2 + .../ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl | 193 +++++++++++++++++ .../wgsl-shaders/mul_mat_id_gather.wgsl | 55 +++++ 5 files changed, 585 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 669d2cd53a8..c10157766d9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -658,6 +658,26 @@ struct ggml_webgpu_mul_mat_shader_decisions { uint32_t mul_mat_wg_size; }; +/** MUL_MAT_ID **/ + +struct ggml_webgpu_mul_mat_id_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + + bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type; + } +}; + +struct ggml_webgpu_mul_mat_id_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_id_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + return seed; + } +}; + /** Cpy **/ struct ggml_webgpu_cpy_pipeline_key { @@ -797,7 +817,10 @@ class ggml_webgpu_shader_lib { std::unordered_map mul_mat_vec_pipelines; // fast mat-vec (n==1) std::unordered_map - mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + std::unordered_map mul_mat_id_gather_pipelines; // key is fixed + std::unordered_map + mul_mat_id_pipelines; // src0_type/src1_type std::unordered_map set_rows_pipelines; @@ -1598,6 +1621,115 @@ class ggml_webgpu_shader_lib { return mul_mat_legacy_pipelines[key]; } + webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) { + auto it = mul_mat_id_gather_pipelines.find(1); + if (it != mul_mat_id_gather_pipelines.end()) { + return it->second; + } + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_mul_mat_id_gather, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, "mul_mat_id_gather"); + pipeline.context = decisions; + mul_mat_id_gather_pipelines[1] = pipeline; + return pipeline; + } + + webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_id_pipeline_key key = { + .src0_type = context.src0->type, + .src1_type = context.src1->type, + }; + + auto it = mul_mat_id_pipelines.find(key); + if (it != mul_mat_id_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "mul_mat_id"; + defines.push_back("MUL_MAT_ID"); + + // src1 type + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); + } + + // src0 type + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + const char * src0_name = src0_traits->type_name; + + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f16"; + break; + default: + { + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("INIT_SRC0_SHMEM_" + type_upper); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + + variant += std::string("_") + src0_name; + break; + } + } + + defines.push_back("SCALAR"); + + // Tiles + defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); + defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); + defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); + + defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); + defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); + + // variant suffix for src1 type + variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + + auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines); + + auto decisions = std::make_shared(); + decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; + decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; + decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; + decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; + decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_id_pipelines[key] = pipeline; + return mul_mat_id_pipelines[key]; + } + webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool is_unary = context.dst->op == GGML_OP_UNARY; const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5c567dc0df0..5b118393640 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1376,6 +1376,163 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); } +static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + // Get or create pipeline + webgpu_pipeline gather_pipeline, main_pipeline; + + std::vector pipelines; + std::vector> params_list; + std::vector> entries_list; + std::vector> workgroups_list; + + gather_pipeline = ctx->shader_lib->get_mul_mat_id_gather_pipeline(shader_lib_ctx); + main_pipeline = ctx->shader_lib->get_mul_mat_id_pipeline(shader_lib_ctx); + + const uint32_t param_n_expert = (uint32_t) src0->ne[2]; + const uint32_t param_n_expert_used = (uint32_t) dst->ne[1]; + const uint32_t param_n_tokens = (uint32_t) dst->ne[2]; + + // params for mul_mat_id_gather.wgsl + std::vector gather_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + param_n_expert, + param_n_expert_used, + param_n_tokens, + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + }; + + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t gathered_buf_nbytes = src0->ne[2] * src1->ne[2] * sizeof(uint32_t); + + const size_t gathered_expert_used_align_offset = ROUNDUP_POW2( + dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t gathered_tokens_align_offset = + ROUNDUP_POW2(gathered_expert_used_align_offset + gathered_buf_nbytes, + ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t gathered_count_ids_align_offset = + ROUNDUP_POW2(gathered_tokens_align_offset + gathered_buf_nbytes, + ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + + const size_t gathered_binding_size = ROUNDUP_POW2(gathered_buf_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT); + const size_t gathered_count_ids_binding_size = + ROUNDUP_POW2(src0->ne[2] * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + + // bind group entries for mul_mat_id_gather.wgsl + std::vector gather_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src2), + .offset = ggml_webgpu_tensor_align_offset(ctx, src2), + .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = gathered_expert_used_align_offset, + .size = gathered_binding_size }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = gathered_tokens_align_offset, + .size = gathered_binding_size }, + { .binding = 3, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = gathered_count_ids_align_offset, + .size = gathered_count_ids_binding_size }, + }; + + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + + const uint32_t gather_total_wg = param_n_expert; + const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim); + const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x); + + pipelines.push_back(gather_pipeline); + params_list.push_back(std::move(gather_params)); + entries_list.push_back(std::move(gather_entries)); + workgroups_list.push_back({ gather_wg_x, gather_wg_y }); + + // params for mul_mat_id.wgsl + std::vector main_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + param_n_expert, + param_n_expert_used, + param_n_tokens, + (uint32_t) src1->ne[1], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + }; + + // bind group entries for mul_mat_id.wgsl + std::vector main_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + { .binding = 3, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = gathered_expert_used_align_offset, + .size = gathered_binding_size }, + { .binding = 4, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = gathered_tokens_align_offset, + .size = gathered_binding_size }, + { .binding = 5, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = gathered_count_ids_align_offset, + .size = gathered_count_ids_binding_size }, + }; + + // Calculate workgroup dimensions + uint32_t wg_x = 1; + uint32_t wg_y = 1; + + auto * main_decisions = static_cast(main_pipeline.context.get()); + + uint32_t wg_m; + + uint32_t tile_m_s = main_decisions->tile_m * main_decisions->wg_size_m; + uint32_t tile_n_s = main_decisions->tile_n * main_decisions->wg_size_n; + wg_m = CEIL_DIV(dst->ne[0], tile_m_s); + uint32_t total_gathered = dst->ne[1] * dst->ne[2]; + uint32_t max_active_experts = std::min((uint32_t) src0->ne[2], total_gathered); + uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts; + uint32_t total_wg = wg_m * max_wg_n; + + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + pipelines.push_back(main_pipeline); + params_list.push_back(std::move(main_params)); + entries_list.push_back(std::move(main_entries)); + workgroups_list.push_back({ wg_x, wg_y }); + + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, + entries_list, workgroups_list); +} + #ifndef __EMSCRIPTEN__ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, wgpu::CommandEncoder & encoder, @@ -2638,6 +2795,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node); case GGML_OP_MUL_MAT: return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node); + case GGML_OP_MUL_MAT_ID: + return ggml_webgpu_mul_mat_id(ctx, encoder, src0, src1, src2, node); case GGML_OP_FLASH_ATTN_EXT: #ifndef __EMSCRIPTEN__ return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node); @@ -3082,6 +3241,20 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer } } break; + case GGML_OP_MUL_MAT_ID: + { + const ggml_tensor * src0 = tensor->src[0]; + const ggml_tensor * src1 = tensor->src[1]; + if (src0 && src1) { + const size_t gathered_size = sizeof(uint32_t) * tensor->src[0]->ne[2] * tensor->src[1]->ne[2]; + const size_t gathered_count_ids_size = sizeof(uint32_t) * tensor->src[0]->ne[2]; + res = ROUNDUP_POW2( + res + gathered_size * 2 + gathered_count_ids_size + + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment * 3, + WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + break; default: break; } @@ -3503,6 +3676,35 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } break; } + case GGML_OP_MUL_MAT_ID: + switch (src1->type) { + case GGML_TYPE_F16: + supports_op |= (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + supports_op = true; + break; + default: + break; + } + break; + default: + break; + } + break; case GGML_OP_FLASH_ATTN_EXT: { #ifndef __EMSCRIPTEN__ diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index eb228537bad..ea91c13468f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -42,6 +42,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } #endif // INIT_SRC0_SHMEM_FLOAT +#ifndef MUL_MAT_ID #ifdef INIT_SRC1_SHMEM_FLOAT fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { @@ -58,6 +59,7 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3 } } #endif // INIT_SRC1_SHMEM_FLOAT +#endif #ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl new file mode 100644 index 00000000000..5f763a6400a --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl @@ -0,0 +1,193 @@ +enable f16; + +#include "common_decls.tmpl" +#include "mul_mat_decls.tmpl" + +#ifdef VEC +fn store_val(acc: array, TILE_N>, tn: u32, tm: u32) -> vec4 { + return vec4(f32(acc[tn][tm]), f32(acc[tn][tm + 1]), f32(acc[tn][tm + 2]), f32(acc[tn][tm + 3])); +} +#endif + +#ifdef SCALAR +fn store_val(acc: array, TILE_N>, tn: u32, tm: u32) -> f32 { + return f32(acc[tn][tm]); +} +#endif + +struct MulMatIdParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + k: u32, + m: u32, + n_expert: u32, + n_expert_used: u32, + n_tokens: u32, + b_ne1: u32, + + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, +}; + +@group(0) @binding(0) var src0: array; // [cols, rows, n_expert] +@group(0) @binding(1) var src1: array; // [cols, b_ne1, n_tokens] +@group(0) @binding(2) var dst: array; // [rows, n_expert_used, n_tokens] +@group(0) @binding(3) var global_gathered_expert_used: array; // [n_expert][n_tokens] +@group(0) @binding(4) var global_gathered_tokens: array; // [n_expert][n_tokens] +@group(0) @binding(5) var gathered_count_ids: array; // [n_expert] + +@group(0) @binding(6) var params: MulMatIdParams; + +fn get_local_n(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_M; +} +fn get_local_m(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_M; +} + +const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; + +var shmem: array; +var gathered_expert_used: array; +var gathered_tokens: array; + +#ifdef INIT_SRC1_SHMEM_FLOAT +fn init_shmem_id_src1(thread_id: u32, offset_src1: u32, rest_token_n: u32, k_outer: u32) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { + let tile_n = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + if (tile_n < rest_token_n) { + let global_src10 = k_outer + tile_k; + let expert_used_idx = gathered_expert_used[tile_n] % params.b_ne1; + let token_idx = gathered_tokens[tile_n]; + let src1_idx = offset_src1 + token_idx * params.stride_12 + expert_used_idx * params.stride_11 + global_src10; + let src1_val = select( + SRC1_TYPE(0.0), + src1[src1_idx/VEC_SIZE], + global_src10 < params.k); + store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx); + } else { + store_shmem(SHMEM_TYPE(0.0), TILE_SRC0_SHMEM + elem_idx); + } + } +} +#endif // INIT_SRC1_SHMEM_FLOAT + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + + let thread_id = local_id.x; + let local_m = get_local_m(thread_id); + let local_n = get_local_n(thread_id); + + var expert_idx:u32 = 0xFFFFFFFFu; + var wg_in_batch:u32 = 0; + var wg_sum:u32 = 0; + let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + for (var i = 0u;i < params.n_expert;i += 1) { + let wg_n_count = (gathered_count_ids[i] + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); + let wg_per_matrix = wg_m_count * wg_n_count; + if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) { + expert_idx = i; + wg_in_batch = wg_linear - wg_sum; + break; + } + wg_sum += wg_per_matrix; + } + + let is_valid = expert_idx != 0xFFFFFFFFu; + + var wg_m: u32 = 0; + var wg_n: u32 = 0; + var offset_wg_m: u32 = 0; + var offset_wg_n: u32 = 0; + var rest_token_n: u32 = 0; + var src0_batch_offset: u32 = 0; + + wg_m = wg_in_batch % wg_m_count; + wg_n = wg_in_batch / wg_m_count; + + offset_wg_m = wg_m * WORKGROUP_SIZE_M * TILE_M; + offset_wg_n = wg_n * WORKGROUP_SIZE_N * TILE_N; + + if (is_valid) { + rest_token_n = gathered_count_ids[expert_idx] - offset_wg_n; + let global_gathered_base = expert_idx * params.n_tokens + offset_wg_n; + for (var i = thread_id; i < TILE_N * WORKGROUP_SIZE_N && offset_wg_n + i < gathered_count_ids[expert_idx]; i += TOTAL_WORKGROUP_SIZE) { + gathered_expert_used[i] = global_gathered_expert_used[global_gathered_base + i]; + gathered_tokens[i] = global_gathered_tokens[global_gathered_base + i]; + } + src0_batch_offset = params.offset_src0 + expert_idx * params.stride_02; + } + + workgroupBarrier(); + + let output_row_base = offset_wg_m + local_m * TILE_M; + let output_col_base = offset_wg_n + local_n * TILE_N; + + let dst2_stride = params.m * params.n_expert_used; + let dst1_stride = params.m; + + var acc: array, TILE_N>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + if (is_valid) { + init_shmem_src0(thread_id, src0_batch_offset, offset_wg_m, k_outer); + init_shmem_id_src1(thread_id, params.offset_src1, rest_token_n, k_outer); + } + + workgroupBarrier(); + + if (is_valid) { + let k_end = min(TILE_K, params.k - k_outer); + + for (var k_inner = 0u; k_inner < k_end; k_inner++) { + var src0_tile: array; + for (var tm = 0u; tm < TILE_M; tm++) { + let src0_m = local_m * TILE_M + tm; + let src0_idx = k_inner + src0_m * TILE_K; + src0_tile[tm] = shmem[src0_idx]; + } + for (var tn = 0u; tn < TILE_N; tn++) { + let src1_n = local_n * TILE_N + tn; + let src1_idx = src1_n * TILE_K + k_inner; + let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; + for (var tm = 0u; tm < TILE_M; tm++) { + acc[tn][tm] += src0_tile[tm] * src1_val; + } + } + } + } + + workgroupBarrier(); + } + + if (is_valid) { + for (var tn = 0u; tn < TILE_N; tn++) { + let n_idx = output_col_base + tn; + if (n_idx < gathered_count_ids[expert_idx]) { + let dst1_idx = gathered_expert_used[n_idx - offset_wg_n]; + let dst2_idx = gathered_tokens[n_idx - offset_wg_n]; + let dst12_offset = params.offset_dst + dst2_idx * dst2_stride + dst1_idx * dst1_stride; + for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) { + let global_row = output_row_base + tm; + if (global_row < params.m) { + let dst_idx = dst12_offset + global_row; + dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm); + } + } + } + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl new file mode 100644 index 00000000000..d79d5f3f282 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl @@ -0,0 +1,55 @@ +enable f16; + +struct MulMatIdGatherParams { + offset_ids: u32, + + n_expert: u32, + n_expert_used: u32, + n_tokens: u32, + + stride_ids_1: u32, +}; + +@group(0) @binding(0) var ids: array; // [n_expert_used, n_tokens] +@group(0) @binding(1) var global_gathered_expert_used: array; // [n_expert][n_tokens] +@group(0) @binding(2) var global_gathered_tokens: array; // [n_expert][n_tokens] +@group(0) @binding(3) var gathered_count_ids: array; // [n_expert] + +@group(0) @binding(4) var params: MulMatIdGatherParams; + +var count:atomic; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + + let thread_id = local_id.x; + let own_expert = wg_id.y * num_wg.x + wg_id.x; // the expert assigned to this workgroup + + if (own_expert < params.n_expert) { + if (thread_id == 0u) { + atomicStore(&count, 0); + } + + workgroupBarrier(); + + for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) { + let row = i / params.n_expert_used; + let col = i % params.n_expert_used; + let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]); + if (own_expert == expert) { + let pos = atomicAdd(&count, 1u); + let gathered_id = own_expert * params.n_tokens + pos; + global_gathered_expert_used[gathered_id] = col; + global_gathered_tokens[gathered_id] = row; + } + } + + workgroupBarrier(); + + if (thread_id == 0u) { + gathered_count_ids[own_expert] = atomicLoad(&count); + } + } +} From 1ebf3cafa03bf94ae71795f2ceb4a3b2effc7cea Mon Sep 17 00:00:00 2001 From: PMZFX Date: Tue, 7 Apr 2026 04:12:49 -0400 Subject: [PATCH 093/249] Add Q8_0 reorder optimization (~3x tg speedup on Intel Arc) (llama/21527) Extend the existing reorder optimization to Q8_0. The reorder separates scale factors from weight data for coalesced memory access -- was implemented for Q4_0/Q4_K/Q6_K but Q8_0 was missing. On Arc Pro B70 (Xe2), Q8_0 tg goes from 4.88 to 15.24 t/s (3.1x) on Qwen3.5-27B. BW utilization: 21% -> 66%. The key fix beyond the kernels: Q8_0 was missing from the type check in ggml_backend_sycl_buffer_init_tensor() that allocates the extra struct carrying the reorder flag -- so the optimization was silently skipped. AI (Claude) was used to assist with root cause investigation and writing the kernel code. All code was human-reviewed and tested on real hardware. Fixes: #21517 --- ggml/src/ggml-sycl/dequantize.hpp | 16 +++++ ggml/src/ggml-sycl/dmmv.cpp | 104 +++++++++++++++++++++++++++++- ggml/src/ggml-sycl/ggml-sycl.cpp | 42 +++++++++++- ggml/src/ggml-sycl/mmvq.cpp | 27 +++++++- ggml/src/ggml-sycl/quants.hpp | 21 ++++++ ggml/src/ggml-sycl/vecdotq.hpp | 40 ++++++++++++ 6 files changed, 247 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 3272724f41b..f992db33b2d 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -143,6 +143,22 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib, #endif // GGML_SYCL_F16 } +static __dpct_inline__ void dequantize_q8_0_reorder(const void *d_ptr, const int64_t ib, const void *qs, + const int iqs, dfloat2 &v) { + const dfloat d = (const dfloat)*((const sycl::half*)d_ptr + ib); + + v.x() = ((const int8_t *)qs)[iqs + 0]; + v.y() = ((const int8_t *)qs)[iqs + 1]; + +#ifdef GGML_SYCL_F16 + v.s0() *= d; + v.s1() *= d; +#else + v.x() *= d; + v.y() *= d; +#endif // GGML_SYCL_F16 +} + static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q8_0 * x = (const block_q8_0 *) vx; diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 4f2760110c2..1c8b6f3771f 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -972,6 +972,103 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y, } } +static void dequantize_mul_mat_vec_q8_0_sycl_reorder(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + // Q8_0 reorder layout: [all qs (ncols*nrows bytes)][all d values] + // Cannot reuse dequantize_mul_mat_vec_reorder template because it has + // Q4_0-specific constants hardcoded (d_ptr offset and qs stride). + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row >= nrows) return; + + const int tid = item_ct1.get_local_id(2); + const int iter_stride = 8*2*GGML_SYCL_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; + const int ncols_left = ncols % (QK8_0*WARP_SIZE); + const int ncols_align = ncols - ncols_left; + +#ifdef GGML_SYCL_F16 + sycl::half2 tmp = {0.0f, 0.0f}; +#else + float tmp = 0.0f; +#endif + const char *d_ptr = (const char*)vx + ncols*nrows; // d after all qs + + int i = 0; + for (i = 0; i < ncols_align; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/QK8_0; + const int iqs = col % QK8_0; + +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + dfloat2 v; + dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx, + ib * QK8_0 + iqs + j, v); + +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[col + j + 0], y[col + j + 1]}; + tmp += v * t1; +#else + tmp += v.x() * y[col + j + 0]; + tmp += v.y() * y[col + j + 1]; +#endif + } + } + + // handle remaining columns + for (; i < ncols; i += iter_stride) { + if (tid >= ncols_left/QK8_0) continue; + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/QK8_0; + const int iqs = col % QK8_0; + +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + dfloat2 v; + dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx, + ib * QK8_0 + iqs + j, v); + +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[col + j + 0], y[col + j + 1]}; + tmp += v * t1; +#else + tmp += v.x() * y[col + j + 0]; + tmp += v.y() * y[col + j + 1]; +#endif + } + } + + // reduce + const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2; + for (int mask = mask_start; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { +#ifdef GGML_SYCL_F16 + dst[row] = tmp.x() + tmp.y(); +#else + dst[row] = tmp; +#endif + } + }); + } +} + static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, @@ -1122,7 +1219,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q8_0: - dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q8_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q2_K: dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 28be4939784..e80ead9aea4 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -411,7 +411,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, assert(tensor->view_src->buffer->buft == buffer->buft); return GGML_STATUS_SUCCESS; } - if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) && + if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) && !g_ggml_sycl_disable_optimize) { ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; tensor->extra = extra; @@ -3254,6 +3254,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: return true; case GGML_TYPE_Q4_K: case GGML_TYPE_Q6_K: @@ -3266,6 +3267,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: return true; default: return false; @@ -3275,6 +3277,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_K: case GGML_TYPE_Q6_K: return true; @@ -3364,6 +3367,40 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr sycl_ext_free(stream, tmp_buf); } +static void reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, + dpct::queue_ptr stream) { + uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } + + GGML_ASSERT((size % sizeof(block_q8_0) == 0)); + GGML_ASSERT((offset % sizeof(block_q8_0) == 0)); + int offset_blks = offset / sizeof(block_q8_0); + auto qs_ptr = data_device + offset_blks * QK8_0; + auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows) + offset_blks; + + auto reorder_event = stream->parallel_for( + size / sizeof(block_q8_0), + [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + const block_q8_0* x = (const block_q8_0*)tmp_buf; + const int ib = i; + + for (int j = 0; j < QK8_0; j++) + { + *((int8_t*)qs_ptr + ib * QK8_0 + j) = x[ib].qs[j]; + } + *(d_ptr + ib) = x[ib].d; + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + sycl_ext_free(stream, tmp_buf); +} + static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q4_K) == 0); GGML_ASSERT(offset % sizeof(block_q4_K) == 0); @@ -3460,6 +3497,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { case GGML_TYPE_Q4_0: reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); break; + case GGML_TYPE_Q8_0: + reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); + break; case GGML_TYPE_Q4_K: reorder_qw_q4_k(data_device, size, 0, stream); break; diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 5abc50fabfe..af22b98dddb 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -679,6 +679,25 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -1101,7 +1120,13 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q2_K: mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 14490fea5be..1f5b62740a8 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -105,6 +105,27 @@ template <> struct block_q_t { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK8_0; // 32 + static constexpr uint32_t qi = QI8_0; // 8 + static constexpr uint32_t qr = QR8_0; // 1 + static constexpr uint32_t vdr_mmvq = 4; + }; + + // Q8_0 reorder layout: [qs0|qs1|...|qsN][d0|d1|...|dN] + // Each block has 32 int8 weights (32 bytes) followed by all scales + static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { + return { block_index * QK8_0, 0 }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + return { (ncols * nrows) + block_index * sizeof(ggml_half), 0 }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } // 1 +}; + } // namespace ggml_sycl_reordered #endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index eab9850aed7..9253168e5ea 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -351,6 +351,46 @@ template <> struct reorder_vec_dot_q_sycl { }; }; +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q8_0; + + using q8_0_block = ggml_sycl_reordered::block_q_t; + using q8_0_traits = typename q8_0_block::traits; + + __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const float & d8_0, const sycl::half2 & ds8) { + int sumi = 0; + +#pragma unroll + for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { + // Q8_0 values are signed int8, no nibble extraction needed + // Direct dp4a: each int packs 4 int8 values + sumi = dpct::dp4a(v[i], u[i], sumi); + } + + const sycl::float2 ds8f = ds8.convert(); + + // Q8_0 has no bias term (values are signed), so just scale + return d8_0 * sumi * ds8f.x(); + } + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const int8_t * bq8_0 = static_cast(vbq) + ibx_offset.first; + const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset.first)); + int v[q8_0_traits::vdr_mmvq]; + int u[q8_0_traits::vdr_mmvq]; + +#pragma unroll + for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { + v[i] = get_int_from_int8(bq8_0, iqs + i); + u[i] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i); + } + + return vec_dot_q8_0_q8_1_impl(v, u, d, *q8_1_ds); + }; +}; + static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales, const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { From a1f76fb4cfd05ed08c96d2f569379551f6e6989f Mon Sep 17 00:00:00 2001 From: Antoine Viallon Date: Tue, 7 Apr 2026 12:18:55 +0200 Subject: [PATCH 094/249] ggml-cuda : fix CDNA2 compute capability constant for gfx90a (MI210) (llama/21519) GGML_CUDA_CC_CDNA2 was set to 0x910 Fix by setting the constant to 0x90a to match the actual gfx90a ISA. --- ggml/src/ggml-cuda/common.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9affe023403..1c9233b4fc1 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -65,7 +65,7 @@ #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a #define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers -#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing +#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x90a) // MI210 (gfx90a), minimum acc register renaming #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 // RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 From 18c98ffaf7355917935915cfedd95414fccdc1a2 Mon Sep 17 00:00:00 2001 From: mkoker <132301062+mkoker@users.noreply.github.com> Date: Tue, 7 Apr 2026 07:41:29 -0400 Subject: [PATCH 095/249] vulkan: add FA dequant for q4_1, q5_0, q5_1, iq4_nl (llama/21029) Add dequantize4() implementations for Q4_1, Q5_0, Q5_1, and IQ4_NL in the flash attention base shader. Register them in the shader generator, pipeline creation, and enable in the scalar/coopmat1 FA support check. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 24 +++-- .../vulkan-shaders/flash_attn_base.glsl | 102 +++++++++++++++++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +- 3 files changed, 118 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 15ed5b2a79d..19e7fbdaae7 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3447,11 +3447,19 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, ) } else { CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32) } #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { @@ -3459,6 +3467,10 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1) } #endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) @@ -15331,11 +15343,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: - // supported in scalar and coopmat2 paths - break; case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_IQ4_NL: + // supported in scalar and coopmat2 paths + break; // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently //case GGML_TYPE_Q2_K: //case GGML_TYPE_Q3_K: @@ -15350,12 +15363,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_IQ3_XXS: //case GGML_TYPE_IQ3_S: //case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ4_NL: - // currently supported only in coopmat2 path - if (!coopmat2) { - return false; - } - break; + default: return false; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 172d38f034e..b30dee86871 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -110,7 +110,11 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 +#elif defined(DATA_A_Q4_1) +#define BLOCK_BYTE_SIZE 20 +#endif +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); @@ -119,7 +123,12 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q4_1 + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); +#endif } else { uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); @@ -127,11 +136,100 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q4_1 + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); +#endif } } #endif +#if defined(DATA_A_Q5_0) +#define BLOCK_BYTE_SIZE 22 +#elif defined(DATA_A_Q5_1) +#define BLOCK_BYTE_SIZE 24 +#endif + +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + +#ifdef DATA_A_Q5_1 + uint qh = k_packed.k_data_packed16[a_offset + ib].qh; +#else + uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16); +#endif + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); + + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q5_1 + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); +#endif + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + +#ifdef DATA_A_Q5_1 + uint qh = v_packed.v_data_packed16[a_offset + ib].qh; +#else + uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16); +#endif + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); + + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q5_1 + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); +#endif + } +} +#endif + + +#if defined(DATA_A_IQ4_NL) +#define BLOCK_BYTE_SIZE 18 + +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( + kvalues_iq4nl[vui_lo & 0xF], + kvalues_iq4nl[(vui_lo >> 8) & 0xF], + kvalues_iq4nl[vui_hi & 0xF], + kvalues_iq4nl[(vui_hi >> 8) & 0xF]); + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( + kvalues_iq4nl[vui_lo & 0xF], + kvalues_iq4nl[(vui_lo >> 8) & 0xF], + kvalues_iq4nl[vui_hi & 0xF], + kvalues_iq4nl[(vui_hi >> 8) & 0xF]); + } +} +#endif #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 8186dba36f6..bf04f4822eb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -655,7 +655,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); @@ -666,7 +666,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); From 78b4fd85e13faf0dd25c3e94b3aecac3b4def041 Mon Sep 17 00:00:00 2001 From: Tom Overlund Date: Tue, 7 Apr 2026 07:54:55 -0400 Subject: [PATCH 096/249] ggml: Vulkan build, Linux -- output error string for errno on fork failure (#20868) (llama/20904) --- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index bf04f4822eb..7afdcef7d22 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -137,6 +137,7 @@ void execute_command(std::vector& command, std::string& stdout_str, pid_t pid = fork(); if (pid < 0) { + std::cerr << strerror(errno) << "\n"; throw std::runtime_error("Failed to fork process"); } From f1d2b83db08b3ae11b72e0044015d1b196bbbdce Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Apr 2026 15:28:27 +0300 Subject: [PATCH 097/249] ggml : deprecate GGML_OP_ADD1 (llama/21363) * ggml : deprecate GGML_OP_ADD1 * cont : remove tests * cont : re-enable vulkan check --- ggml/include/ggml.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3bb2faa2c66..11d3e8a8167 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -902,15 +902,17 @@ extern "C" { struct ggml_tensor * b, struct ggml_tensor * ids); - GGML_API struct ggml_tensor * ggml_add1( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b), + "use ggml_add instead"); - GGML_API struct ggml_tensor * ggml_add1_inplace( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b), + "use ggml_add_inplace instead"); // dst = a // view(dst, nb1, nb2, nb3, offset) += b From 5ef7aafa0678070ce3cb428162c0c51fafc54d51 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 8 Apr 2026 00:57:04 +0800 Subject: [PATCH 098/249] CUDA: check for buffer overlap before fusing (llama/21566) * CUDA: check for buffer overlap before fusing * use ggml_cuda_check_fusion_memory_ranges --- ggml/src/ggml-cuda/ggml-cuda.cu | 138 ++++++++++++++++---------------- 1 file changed, 71 insertions(+), 67 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 75b62129ade..25b904b7dc2 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3308,6 +3308,71 @@ static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int nod return true; } +// returns whether the write (out) nodes overwrite the read nodes in operation +static bool ggml_cuda_check_fusion_memory_ranges(const ggml_cgraph * cgraph, + const int node_idx, + const int node_count, + const int * out_nodes, + const int out_count, + const bool is_topk_moe = false) { + auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) { + const int64_t a_start = (int64_t) a->data; + const int64_t a_end = a_start + ggml_backend_buft_get_alloc_size(a->buffer->buft, a); + + const int64_t b_start = (int64_t) b->data; + const int64_t b_end = b_start + ggml_backend_buft_get_alloc_size(b->buffer->buft, b); + + if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) { + return true; + } + + return false; + }; + + bool is_ok = true; + // exception for topk-moe, as each row is read entirely before writing + if (ggml_nrows(cgraph->nodes[node_idx]) == 1 && is_topk_moe) { + return true; + } + + for (int i = 0; i < out_count; ++i) { + const ggml_tensor * dst = cgraph->nodes[out_nodes[i]]; + + for (int j = node_idx; j < node_idx + node_count; ++j) { + // Loop over all srcs of all nodes in the fusion. If the src overlaps + // the destination and the src is not an intermediate node that's being + // elided, then disable fusion. + + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + const ggml_tensor * src = cgraph->nodes[j]->src[src_idx]; + + if (!src || src->op == GGML_OP_NONE) { + continue; + } + + if (nodes_overlap(dst, src)) { + bool found = false; + + for (int k = node_idx; k < j; ++k) { + if (cgraph->nodes[k] == src) { + found = true; + break; + } + } + + if (!found) { + is_ok = false; + break; + } + } + } + } + } + + return is_ok; +} + + static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops, @@ -3337,7 +3402,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, const ggml_tensor * glu = cgraph->nodes[node_idx + 4]; if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) { - return true; + int out_nodes[] = { node_idx + 4 }; + return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1); } } @@ -3348,7 +3414,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, const ggml_tensor * glu = cgraph->nodes[node_idx + 2]; if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) { - return true; + int out_nodes[] = { node_idx + 2 }; + return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1); } } @@ -3474,69 +3541,6 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return false; } -// returns whether the write (out) nodes overwrite the read nodes in operation -static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph, - int node_idx, - int node_count, - int * out_nodes, - int out_count) { - auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) { - const int64_t a_start = (int64_t) a->data; - const int64_t a_end = a_start + ggml_nbytes(a); - - const int64_t b_start = (int64_t) b->data; - const int64_t b_end = b_start + ggml_nbytes(b); - - if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) { - return true; - } - - return false; - }; - - bool is_ok = true; - // for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok - if (ggml_nrows(cgraph->nodes[node_idx]) == 1) { - return true; - } - - for (int i = 0; i < out_count; ++i) { - const ggml_tensor * dst = cgraph->nodes[out_nodes[i]]; - - for (int j = node_idx; j < node_idx + node_count; ++j) { - // Loop over all srcs of all nodes in the fusion. If the src overlaps - // the destination and the src is not an intermediate node that's being - // elided, then disable fusion. - - for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { - const ggml_tensor * src = cgraph->nodes[j]->src[src_idx]; - - if (!src || src->op == GGML_OP_NONE) { - continue; - } - - if (nodes_overlap(dst, src)) { - bool found = false; - - for (int k = node_idx; k < j; ++k) { - if (cgraph->nodes[k] == src) { - found = true; - break; - } - } - - if (!found) { - is_ok = false; - break; - } - } - } - } - } - - return is_ok; -} - static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { bool graph_evaluated_or_captured = false; @@ -3734,7 +3738,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && ggml_cuda_should_use_topk_moe(node, logits, weights, ids) && - ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) { + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) { ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); i += ops.size() - 1; continue; @@ -3750,7 +3754,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud int out_nodes[2] = { i + 1, i + 5 }; if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) && - ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) { + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) { ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); i += ops.size() - 1; continue; From d1456437e1867fa957eb298648c68e48261ee476 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 7 Apr 2026 10:30:01 -0700 Subject: [PATCH 099/249] ggml-webgpu: parameterize submission size and add iOS specific limits (llama/21533) * Work towards removing bitcast * Move rest of existing types over * Add timeout back to wait and remove synchronous set_tensor/memset_tensor * move to unpackf16 for wider compatibility * cleanup * Remove deadlock condition in free_bufs * Start work on removing parameter buffer pools * Simplify and optimize further * simplify profile futures * Fix stride * Try using a single command buffer per batch * formatting * Add parameters for different browsers in-flight submissions * Update handling of batch size too * Throttle ios as much as possible * Increase timeout for llvm-pipe testing --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 148 ++++++++++++++++++++------- 1 file changed, 113 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5b118393640..3d038924b78 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -16,7 +16,6 @@ #include #include -#include #include #include #ifdef GGML_WEBGPU_GPU_PROFILE @@ -25,7 +24,6 @@ #if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE) # include #endif -#include #include #include #include @@ -81,13 +79,13 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim /* Constants */ -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u -#define WEBGPU_NUM_PARAM_SLOTS \ - (WEBGPU_COMMAND_SUBMIT_BATCH_SIZE + 10) // a few extra for safety, since some operations may need multiple slots -#define WEBGPU_WAIT_ANY_TIMEOUT_MS 100 -#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters -#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 -#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 +#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 32u +#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u +#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u +#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6) +#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters +#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 +#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 // For operations which process a row in parallel, this seems like a reasonable // default @@ -252,6 +250,8 @@ struct webgpu_global_context_struct { wgpu::Adapter adapter; wgpu::Device device; wgpu::Queue queue; + uint32_t command_submit_batch_size = WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE; + uint32_t max_inflight_batches = UINT32_MAX; webgpu_capabilities capabilities; // Shared buffer to move data from device to host @@ -417,16 +417,72 @@ static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & } #endif +template +static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status, + T callback_status, + T success_status, + const char * wait_name, + const char * failure_name, + const char * callback_message) { + if (wait_status == wgpu::WaitStatus::TimedOut) { + GGML_ABORT("ggml_webgpu: %s timed out after %u ms\n", wait_name, WEBGPU_RUNTIME_WAIT_TIMEOUT_MS); + } + if (wait_status == wgpu::WaitStatus::Error) { + GGML_ABORT("ggml_webgpu: %s failed\n", wait_name); + } + if (callback_status != success_status) { + GGML_ABORT("ggml_webgpu: %s failed with status %d: %s\n", failure_name, static_cast(callback_status), + callback_message); + } +} + +#ifdef __EMSCRIPTEN__ +// iOS browsers seem to have very strict limits on the number of in-flight GPU commands, so we need to throttle to avoid failures. +EM_JS(int, ggml_webgpu_is_ios_browser, (), { + const ua = navigator.userAgent; + return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0; +}); +#endif + +static uint32_t ggml_backend_webgpu_get_max_inflight_batches(const wgpu::AdapterInfo & info) { +#ifdef __EMSCRIPTEN__ + if (ggml_webgpu_is_ios_browser()) { + return 1; + } +#else + GGML_UNUSED(info); +#endif + + return UINT32_MAX; +} + +static uint32_t ggml_backend_webgpu_get_command_submit_batch_size(const wgpu::AdapterInfo & info) { +#ifdef __EMSCRIPTEN__ + if (ggml_webgpu_is_ios_browser()) { + return 16; + } +#else + GGML_UNUSED(info); +#endif + + return WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE; +} + static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) { - ctx->instance.WaitAny( - ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", - std::string(message).c_str()); - } - }), - UINT64_MAX); + wgpu::QueueWorkDoneStatus callback_status = wgpu::QueueWorkDoneStatus::Error; + std::string callback_message; + + const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( + ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, + [&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + callback_status = status; + callback_message = std::string(message); + }), + WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + + ggml_backend_webgpu_check_wait_status(wait_status, callback_status, wgpu::QueueWorkDoneStatus::Success, + "Queue wait", "Queue work", callback_message.c_str()); } static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, @@ -434,14 +490,31 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, wgpu::MapMode mode, size_t offset, size_t size) { - ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", - message.data); - } - }), - UINT64_MAX); + wgpu::MapAsyncStatus callback_status = wgpu::MapAsyncStatus::Error; + std::string callback_message; + + const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( + buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, + [&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) { + callback_status = status; + callback_message = std::string(message); + }), + WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + + ggml_backend_webgpu_check_wait_status(wait_status, callback_status, wgpu::MapAsyncStatus::Success, + "Buffer map wait", "Buffer map", callback_message.c_str()); +} + +static void ggml_backend_webgpu_submit_commands(webgpu_context & ctx, + const wgpu::CommandBuffer commands, + uint32_t & num_inflight_batches) { + if (num_inflight_batches >= ctx->global_ctx->max_inflight_batches) { + ggml_backend_webgpu_wait_queue(ctx->global_ctx); + num_inflight_batches = 0; + } + + ctx->global_ctx->queue.Submit(1, &commands); + num_inflight_batches++; } #ifdef GGML_WEBGPU_DEBUG @@ -2871,9 +2944,10 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str #ifdef GGML_WEBGPU_GPU_PROFILE std::vector profile_futures; #endif - uint32_t num_batched_kernels = 0; - bool contains_set_rows = false; - wgpu::CommandEncoder batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + uint32_t num_batched_kernels = 0; + uint32_t num_inflight_batches = 0; + bool contains_set_rows = false; + wgpu::CommandEncoder batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { @@ -2884,10 +2958,10 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str num_batched_kernels += cmd.value().num_kernels; } - if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { + if (num_batched_kernels >= ctx->global_ctx->command_submit_batch_size) { num_batched_kernels = 0; wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); - ctx->global_ctx->queue.Submit(1, &batch_commands); + ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); #ifdef GGML_WEBGPU_GPU_PROFILE ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); #endif @@ -2898,7 +2972,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str } if (!commands.empty()) { wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); - ctx->global_ctx->queue.Submit(1, &batch_commands); + ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); #ifdef GGML_WEBGPU_GPU_PROFILE ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); #endif @@ -2912,7 +2986,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, ctx->set_rows_host_error_buf.GetSize()); wgpu::CommandBuffer set_rows_commands = encoder.Finish(); - ctx->global_ctx->queue.Submit(1, &set_rows_commands); + ggml_backend_webgpu_submit_commands(ctx, set_rows_commands, num_inflight_batches); } ggml_backend_webgpu_wait_queue(ctx->global_ctx); @@ -3363,6 +3437,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { } #endif ctx->webgpu_global_ctx->adapter.GetInfo(&info); + ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(info); + ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(info); wgpu::SupportedFeatures features; ctx->webgpu_global_ctx->adapter.GetFeatures(&features); // we require f16 support @@ -3483,8 +3559,10 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); - webgpu_ctx->param_arena.init(webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, WEBGPU_NUM_PARAM_SLOTS, - webgpu_ctx->global_ctx->capabilities.limits.minUniformBufferOffsetAlignment); + webgpu_ctx->param_arena.init( + webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, + webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN, + webgpu_ctx->global_ctx->capabilities.limits.minUniformBufferOffsetAlignment); ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf"); From d91d1e8e6c362ac503cc659d5be8cafd7c35ab86 Mon Sep 17 00:00:00 2001 From: iacopPBK Date: Tue, 7 Apr 2026 21:47:42 +0200 Subject: [PATCH 100/249] ggml-cuda: ds_read_b128 for q4_0 and q4_1 mmq kernels (llama/21168) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ds_read_b128 for q4_0 and q4_1 mmq kernels Current for loop generates ds_read_b32 instructions with hip compiler, the new solution generates ds_read_b128 instructions for the same operation, saving some LDS bandwidth. Tested on MI50 and RX6800XT, its faster on both. * Vectorized lds load update: used ggml_cuda_get_max_cpy_bytes and ggml_cuda_memcpy_1 functions for generic implementation * Explicit for loop in mmq, renamed vec into tmp * Fixed max_cpy usage in the loading loop * Fixed typo in q4_1 kernel * Update ggml/src/ggml-cuda/mmq.cuh Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/mmq.cuh Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/mmq.cuh Co-authored-by: Johannes Gäßler * Renoved trailing white line 500 * Update mmq.cuh removed other whitelines * Remove trailing whitespaces --------- Co-authored-by: iacopPBK Co-authored-by: Johannes Gäßler Co-authored-by: iacopPBK --- ggml/src/ggml-cuda/mmq.cuh | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 51e8dad4ce7..489d3616bb4 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -386,17 +386,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_0_Q8_1_MMQ]; -#pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; + constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); + constexpr int mcpy_int = max_cpy / sizeof(int); + static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ"); + + int tmp0[4], tmp1[4]; + + #pragma unroll + for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) { + ggml_cuda_memcpy_1(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] ); + ggml_cuda_memcpy_1(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0 + l0 * mcpy_int]); } + u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3]; + u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3]; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -489,17 +497,25 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_1_Q8_1_MMQ]; -#pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; + constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); + constexpr int mcpy_int = max_cpy / sizeof(int); + static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ"); + + int tmp0[4], tmp1[4]; + + #pragma unroll + for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) { + ggml_cuda_memcpy_1(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] ); + ggml_cuda_memcpy_1(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1 + l0 * mcpy_int]); } + u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3]; + u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3]; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -4170,3 +4186,4 @@ void ggml_cuda_op_mul_mat_q( const int64_t src1_padded_row_size, cudaStream_t stream); bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts); + From fa2eaa433bb9aba9bee5fefc0341a9a0b9d6091b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 8 Apr 2026 09:05:51 +0800 Subject: [PATCH 101/249] CUDA: make cuda graphs props check faster (llama/21472) * CUDA: compute fast hash instead of expensive props check * use seen node * use memcp --- ggml/src/ggml-cuda/common.cuh | 21 +----- ggml/src/ggml-cuda/ggml-cuda.cu | 113 ++------------------------------ 2 files changed, 6 insertions(+), 128 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 1c9233b4fc1..a2960e5ae3c 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1157,19 +1157,6 @@ struct ggml_tensor_extra_gpu { #define USE_CUDA_GRAPH #endif -struct ggml_cuda_graph_node_properties { - void * node_data; - ggml_op node_op; - enum ggml_type node_type; - int32_t flags; - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; - void * src_data[GGML_MAX_SRC]; - int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; -}; - -static_assert(std::is_trivial::value, "ggml_cuda_graph_node_properties must be trivial"); - struct ggml_cuda_graph { #ifdef USE_CUDA_GRAPH ~ggml_cuda_graph() { @@ -1186,13 +1173,7 @@ struct ggml_cuda_graph { std::vector nodes; bool disable_due_to_gpu_arch = false; bool warmup_complete = false; - std::vector props; - - // these are extra tensors (inputs) that participate in the ggml graph but are not nodes - // they properties also have to match in order to be able to safely reuse a CUDA graph - // ref: https://github.com/ggml-org/llama.cpp/pull/18583 - // ref: https://github.com/ggml-org/llama.cpp/pull/19165 - std::vector extra; + std::vector nodes_copy; bool is_enabled() const { static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 25b904b7dc2..b21196bb4f3 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -82,7 +82,6 @@ #include #include #include -#include static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); @@ -2969,74 +2968,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { return use_cuda_graph; } -static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { - memset(props, 0, sizeof(ggml_cuda_graph_node_properties)); - props->node_data = node->data; - props->node_op = node->op; - props->node_type = node->type; - props->flags = node->flags; - for (int i = 0; i < GGML_MAX_DIMS; i++) { - props->ne[i] = node->ne[i]; - props->nb[i] = node->nb[i]; - } - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (!node->src[i]) { - continue; - } - - props->src_data[i] = node->src[i]->data; - } - memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); -} - -static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { - if (node->data != props->node_data && node->op != GGML_OP_VIEW) { - return false; - } - - if (node->op != props->node_op) { - return false; - } - - if (node->type != props->node_type) { - return false; - } - - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (node->ne[i] != props->ne[i]) { - return false; - } - if (node->nb[i] != props->nb[i]) { - return false; - } - } - - if (node->op != GGML_OP_VIEW) { - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (!node->src[i]) { - if (props->src_data[i] != nullptr) { - return false; - } - continue; - } - - if (node->src[i]->data != props->src_data[i]) { - return false; - } - } - } - - if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { - return false; - } - - if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) { - return false; - } - - return true; -} - static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) { return cgraph->nodes[0]; } @@ -3048,52 +2979,18 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); // Check if the graph size has changed - if (graph->props.size() != (size_t)cgraph->n_nodes) { + if ((int)graph->nodes_copy.size() != cgraph->n_nodes) { res = true; - graph->props.resize(cgraph->n_nodes); + graph->nodes_copy.resize(cgraph->n_nodes); } - // Loop over nodes in GGML graph to determine if CUDA graph update is required - // and store properties to allow this comparison for the next token - std::unordered_set seen_node; - std::vector srcs_extra; for (int i = 0; i < cgraph->n_nodes; i++) { - bool props_match = true; - - seen_node.insert(cgraph->nodes[i]); - if (!res) { - props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]); - } - if (!props_match) { - res = true; - } - ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]); - - for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { - ggml_tensor * src = cgraph->nodes[i]->src[src_idx]; - if (src && seen_node.find(src) == seen_node.end()) { - srcs_extra.push_back(src); + if (memcmp(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) { + res = true; } } - } - - if (graph->extra.size() != (size_t) srcs_extra.size()) { - res = true; - graph->extra.resize(srcs_extra.size()); - } - - for (size_t i = 0; i < srcs_extra.size(); ++i) { - bool props_match = true; - - if (!res) { - props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]); - } - - if (!props_match) { - res = true; - } - ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]); + memcpy(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)); } return res; From 15deafa31ecdd0a44a9a41494f9f7785068fe822 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Wed, 8 Apr 2026 06:07:47 -0700 Subject: [PATCH 102/249] metal: Q1_0 backend (llama/21528) * initial Q1_0 Metal backend * tuning q1_0 metal kernels * add Q1_0 to test-backend-ops * add Q1_0<->F32 copy test * Apply suggestions from code review Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal-device.cpp | 10 ++ ggml/src/ggml-metal/ggml-metal-device.m | 2 + ggml/src/ggml-metal/ggml-metal-impl.h | 3 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 1 + ggml/src/ggml-metal/ggml-metal.metal | 187 ++++++++++++++++++++++ 5 files changed, 203 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 89539bd7615..e8548b053e8 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -736,6 +736,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta suffix = ne00 % 4 == 0 ? "_4" : ""; } } break; + case GGML_TYPE_Q1_0: + { + nsg = N_SG_Q1_0; + nr0 = N_R0_Q1_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; @@ -948,6 +953,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m smem = 32*sizeof(float)*nr0; suffix = ne00 % 4 == 0 ? "_4" : ""; } break; + case GGML_TYPE_Q1_0: + { + nsg = N_SG_Q1_0; + nr0 = N_R0_Q1_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 17d51b11b6e..40cacb46520 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1184,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1210,6 +1211,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te default: return false; } + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index eb2253e029a..62b028f4a4a 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -8,6 +8,9 @@ // // TODO: for optimal performance, become function of the device and work size +#define N_R0_Q1_0 8 +#define N_SG_Q1_0 2 + #define N_R0_Q4_0 4 #define N_SG_Q4_0 2 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 3cda21be43e..846225d9077 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2047,6 +2047,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16 || + op->src[0]->type == GGML_TYPE_Q1_0 || op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2074211594c..f28bfa0b95b 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg } #endif +template +void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) { + device const uint8_t * qs = xb->qs; + const float d = xb->d; + const float neg_d = -d; + + const int byte_offset = il * 2; // il*16 bits = il*2 bytes + const uint8_t b0 = qs[byte_offset]; + const uint8_t b1 = qs[byte_offset + 1]; + + float4x4 reg_f; + + reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01)); + reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02)); + reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04)); + reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08)); + reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10)); + reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20)); + reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40)); + reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80)); + + reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01)); + reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02)); + reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04)); + reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08)); + reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10)); + reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20)); + reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40)); + reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80)); + + reg = (type4x4) reg_f; +} + +template +void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) { + const float d = xb->d; + const float neg_d = -d; + const int base = il * 4; + const uint8_t byte = xb->qs[base / 8]; + const int s = base % 8; + + float4 reg_f; + reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1)); + reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1)); + reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1)); + reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1)); + + reg = (type4) reg_f; +} + template void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); @@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r } } +void quantize_q1_0(device const float * src, device block_q1_0 & dst) { + float sum_abs = 0.0f; + for (int j = 0; j < QK1_0; j++) { + sum_abs += fabs(src[j]); + } + dst.d = sum_abs / QK1_0; + + for (int j = 0; j < QK1_0 / 8; j++) { + dst.qs[j] = 0; + } + for (int j = 0; j < QK1_0; j++) { + if (src[j] >= 0.0f) { + dst.qs[j / 8] |= (1 << (j % 8)); + } + } +} + void quantize_q4_0(device const float * src, device block_q4_0 & dst) { #pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max @@ -3116,6 +3183,35 @@ kernel void kernel_group_norm_f32( } } +// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy) +inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) { + device const uint8_t * qs = qb_curr->qs + il / 8; + const uint8_t b0 = qs[0]; + const uint8_t b1 = qs[1]; + + float acc = 0.0f; + + acc += select(0.0f, yl[ 0], bool(b0 & 0x01)); + acc += select(0.0f, yl[ 1], bool(b0 & 0x02)); + acc += select(0.0f, yl[ 2], bool(b0 & 0x04)); + acc += select(0.0f, yl[ 3], bool(b0 & 0x08)); + acc += select(0.0f, yl[ 4], bool(b0 & 0x10)); + acc += select(0.0f, yl[ 5], bool(b0 & 0x20)); + acc += select(0.0f, yl[ 6], bool(b0 & 0x40)); + acc += select(0.0f, yl[ 7], bool(b0 & 0x80)); + + acc += select(0.0f, yl[ 8], bool(b1 & 0x01)); + acc += select(0.0f, yl[ 9], bool(b1 & 0x02)); + acc += select(0.0f, yl[10], bool(b1 & 0x04)); + acc += select(0.0f, yl[11], bool(b1 & 0x08)); + acc += select(0.0f, yl[12], bool(b1 & 0x10)); + acc += select(0.0f, yl[13], bool(b1 & 0x20)); + acc += select(0.0f, yl[14], bool(b1 & 0x40)); + acc += select(0.0f, yl[15], bool(b1 & 0x80)); + + return qb_curr->d * (2.0f * acc - sumy); +} + // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) // il indicates where the q4 quants begin (0 or QK4_0/4) // we assume that the yl's have been multiplied with the appropriate scale factor @@ -3337,6 +3433,85 @@ void mul_vec_q_n_f32_impl( } } +template +void kernel_mul_mv_q1_0_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + const int nb = args.ne00/QK1_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * NSG + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; + + device const float * y = (device const float *) (src1 + offset1); + + device const block_q1_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); + } + + float yl[16]; + float sumf[nr0] = {0.f}; + + const short ix = (tiisg/8); + const short il = (tiisg%8)*16; + + device const float * yb = y + ix*QK1_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) { + float sumy = 0.f; + + FOR_UNROLL (short i = 0; i < 16; i++) { + yl[i] = yb[i]; + sumy += yb[i]; + } + + FOR_UNROLL (short row = 0; row < nr0; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il); + } + + yb += QK1_0 * (N_SIMDWIDTH/8); + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q1_0_f32")]] +kernel void kernel_mul_mv_q1_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q1_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + kernel void kernel_mul_mv_q4_0_f32( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3729,6 +3904,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4 template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>; #endif +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>; + template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; @@ -7133,6 +7313,7 @@ kernel void kernel_cpy_f32_q( typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; @@ -7173,12 +7354,14 @@ kernel void kernel_cpy_q_f32( typedef decltype(kernel_cpy_q_f32) cpy_q_f_t; +template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; @@ -9776,6 +9959,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro typedef decltype(kernel_get_rows_q) get_rows_q_t; +template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; @@ -9838,6 +10022,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif +template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; @@ -9861,6 +10046,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm; @@ -10070,6 +10256,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4 template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; From e70c0d43f4a8dba5dc0ba7faa387f88d7b41a74c Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 8 Apr 2026 06:08:29 -0700 Subject: [PATCH 103/249] webgpu : Query for adapter support when registering WebGPU backend (llama/21579) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 39 +++++++++++++++++++--------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3d038924b78..b8df0f4dd05 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -4033,8 +4033,14 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); static ggml_backend_webgpu_reg_context ctx; + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_webgpu_reg_i, + /* .context = */ &ctx, + }; + ctx.name = GGML_WEBGPU_NAME; - ctx.device_count = 1; + ctx.device_count = 0; wgpu::InstanceDescriptor instance_descriptor{}; std::vector instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; @@ -4053,19 +4059,28 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); ctx.webgpu_global_ctx->instance = std::move(inst); -#ifdef __EMSCRIPTEN__ - if (ctx.webgpu_global_ctx->instance == nullptr) { - GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n"); - return nullptr; + wgpu::Adapter adapter; + if (ctx.webgpu_global_ctx->instance != nullptr) { + wgpu::RequestAdapterOptions options = {}; + + // probe for adapter support + ctx.webgpu_global_ctx->instance.WaitAny( + ctx.webgpu_global_ctx->instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + adapter = std::move(_adapter); + }), + UINT64_MAX); + } + + if (adapter != nullptr) { + ctx.device_count = 1; } -#endif - GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr); - static ggml_backend_reg reg = { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_webgpu_reg_i, - /* .context = */ &ctx, - }; return ® } From 16dd1716204773a2f99a29ecf46748fa29d5f2b9 Mon Sep 17 00:00:00 2001 From: RealOrko <45273739+RealOrko@users.noreply.github.com> Date: Wed, 8 Apr 2026 16:40:15 +0100 Subject: [PATCH 104/249] fix: free ctx_copy in ggml_opt_free to plug per-training-session leak (llama/21592) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: free ctx_copy in ggml_opt_free to plug per-training-session leak ggml_opt_alloc populates opt_ctx->ctx_copy via a free+init pair every time the allocated graph shape changes. The last ctx_copy from the final ggml_opt_alloc call survives until ggml_opt_free is invoked, but ggml_opt_free was only freeing ctx_static and ctx_cpu, never ctx_copy. Each opt_ctx lifetime therefore leaks the final per-batch context — ~900 KB for a typical GNN training session in sindarin-pkg-tensor, surfaced via AddressSanitizer. ctx_copy is nullptr-initialized and ggml_free() handles NULL safely, so the new release is guard-free. * Update ggml/src/ggml-opt.cpp Co-authored-by: Johannes Gäßler --------- Co-authored-by: realorko Co-authored-by: Johannes Gäßler --- ggml/src/ggml-opt.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp index e078ad14a39..53903defa8f 100644 --- a/ggml/src/ggml-opt.cpp +++ b/ggml/src/ggml-opt.cpp @@ -589,6 +589,7 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) { ggml_backend_buffer_free(opt_ctx->buf_cpu); ggml_free(opt_ctx->ctx_static); ggml_free(opt_ctx->ctx_cpu); + ggml_free(opt_ctx->ctx_copy); delete opt_ctx; } From 2c7472939fd6d29bc14a8feabdace940d86aecd0 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 9 Apr 2026 01:01:56 +0800 Subject: [PATCH 105/249] CUDA: also store `node->src->data` ptrs for equality check (llama/21635) * CUDA: also store node->src->data ptrs for equality check * address review comments --- ggml/src/ggml-cuda/common.cuh | 6 +++++- ggml/src/ggml-cuda/ggml-cuda.cu | 21 ++++++++++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a2960e5ae3c..65d7a6e22ae 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1173,7 +1173,11 @@ struct ggml_cuda_graph { std::vector nodes; bool disable_due_to_gpu_arch = false; bool warmup_complete = false; - std::vector nodes_copy; + struct node_properties { + ggml_tensor node; + void * node_src_data_ptrs[GGML_MAX_SRC]; + }; + std::vector node_props; bool is_enabled() const { static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b21196bb4f3..648124c0d31 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2979,18 +2979,25 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); // Check if the graph size has changed - if ((int)graph->nodes_copy.size() != cgraph->n_nodes) { + if ((int)graph->node_props.size() != cgraph->n_nodes) { res = true; - graph->nodes_copy.resize(cgraph->n_nodes); + graph->node_props.resize(cgraph->n_nodes); } for (int i = 0; i < cgraph->n_nodes; i++) { - if (!res) { - if (memcmp(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) { - res = true; - } + ggml_cuda_graph::node_properties prop = {}; + memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor)); + + // if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see: + // https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188 + for (int j = 0; j < GGML_MAX_SRC; ++j) { + prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr; + } + + if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) { + res = true; } - memcpy(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)); + graph->node_props[i] = prop; } return res; From 1d555510dedb4656ad7dedb2279201dccd9f5858 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 9 Apr 2026 07:31:51 +0200 Subject: [PATCH 106/249] vulkan: unify type macros to use Vx instead of _VECx (llama/21605) --- .../vulkan-shaders/mul_mat_vec_iface.glsl | 12 +- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 2 +- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 2 +- .../vulkan-shaders/mul_mat_vec_q5_k.comp | 2 +- .../vulkan-shaders/mul_mat_vecq_funcs.glsl | 12 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 16 +- .../vulkan-shaders/mul_mm_funcs.glsl | 192 +++++++++--------- .../vulkan-shaders/mul_mmq_funcs.glsl | 16 +- .../vulkan-shaders/mul_mmq_shmem_types.glsl | 16 +- .../vulkan-shaders/vulkan-shaders-gen.cpp | 64 +++--- 10 files changed, 167 insertions(+), 167 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl index 337dbd796ad..e8d053cdd43 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl @@ -6,8 +6,8 @@ #define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -#if defined(A_TYPE_VEC4) -layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +#if defined(A_TYPEV4) +layout (binding = 0) readonly buffer AV4 {A_TYPEV4 data_a_v4[];}; #endif #if defined(A_TYPE_PACKED16) layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; @@ -17,11 +17,11 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32 #endif layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -#ifdef B_TYPE_VEC2 -layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; +#ifdef B_TYPEV2 +layout (binding = 1) readonly buffer BV2 {B_TYPEV2 data_b_v2[];}; #endif -#ifdef B_TYPE_VEC4 -layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; +#ifdef B_TYPEV4 +layout (binding = 1) readonly buffer BV4 {B_TYPEV4 data_b_v4[];}; #endif layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 619de054cb8..975cec8013f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -41,7 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); - const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm); + const FLOAT_TYPEV2 dm = vec2(data_a[ib0 + i].dm); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 6af5a81587d..93fbacc6282 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); + const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 3695b47b98d..54d7e1bcdca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); + const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index 6ddbed309d7..e99108dc50c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -11,8 +11,8 @@ FLOAT_TYPE get_dm(uint ib) { #endif #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) -FLOAT_TYPE_VEC2 get_dm(uint ib) { - return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +FLOAT_TYPEV2 get_dm(uint ib) { + return FLOAT_TYPEV2(data_a_packed32[ib].dm); } #endif @@ -23,9 +23,9 @@ FLOAT_TYPE get_dm(uint ib) { #endif #if defined(DATA_A_Q2_K) -FLOAT_TYPE_VEC2 get_dm(uint ib) { +FLOAT_TYPEV2 get_dm(uint ib) { const uint ib_k = ib / 8; - return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + return FLOAT_TYPEV2(data_a_packed32[ib_k].dm); } #endif @@ -304,7 +304,7 @@ vec2 get_dm_scale(uint ib, uint iqs) { (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); } - return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); + return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm); } FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { @@ -422,7 +422,7 @@ vec2 get_dm(uint ib, uint iqs) { const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); // the -1 cancels out the bias in iq1s_grid_gpu - return FLOAT_TYPE_VEC2(dl, dl * (delta - 1)); + return FLOAT_TYPEV2(dl, dl * (delta - 1)); } FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 23f3bd8d6d0..89346e48e06 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -125,8 +125,8 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit #define SHMEM_STRIDE (BK / 2 + 1) #endif -shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE]; -shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; +shared FLOAT_TYPEV2 buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPEV2 buf_b[BN * SHMEM_STRIDE]; #define NUM_WARPS (BLOCK_SIZE / WARP) @@ -258,17 +258,17 @@ void main() { sums[i] = coopmat(0.0f); } #else - ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2]; + ACC_TYPEV2 sums[WMITER * TM * WNITER * TN/2]; #if defined(DATA_A_F32) || defined(DATA_A_F16) - FLOAT_TYPE_VEC4 cache_a[WMITER * TM]; - FLOAT_TYPE_VEC4 cache_b; + FLOAT_TYPEV4 cache_a[WMITER * TM]; + FLOAT_TYPEV4 cache_b; #else - FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; - FLOAT_TYPE_VEC2 cache_b; + FLOAT_TYPEV2 cache_a[WMITER * TM]; + FLOAT_TYPEV2 cache_b; #endif [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { - sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f); + sums[i] = ACC_TYPEV2(0.0f, 0.0f); } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 3f494eb4d5a..9b769bbc887 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -3,7 +3,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #if LOAD_VEC_A == 8 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]); + FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]); buf_a[buf_idx ] = aa[0].xy; buf_a[buf_idx + 1] = aa[0].zw; buf_a[buf_idx + 2] = aa[1].xy; @@ -11,38 +11,38 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #elif LOAD_VEC_A == 4 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]); + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]); buf_a[buf_idx ] = aa.xy; buf_a[buf_idx + 1] = aa.zw; #else // LOAD_VEC_BATCH_A == 2 const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], - data_a[idx + 1]); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], + data_a[idx + 1]); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f); } else { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif #elif defined(DATA_A_BF16) #if LOAD_VEC_A == 4 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx])); + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx])); buf_a[buf_idx ] = aa.xy; buf_a[buf_idx + 1] = aa.zw; #else // LOAD_VEC_BATCH_A == 2 const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), - TO_FLOAT_TYPE(data_a[idx + 1])); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), + TO_FLOAT_TYPE(data_a[idx + 1])); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); } else { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif #elif defined(DATA_A_Q4_0) @@ -57,10 +57,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy); - buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v0.zw); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(v1.xy); + buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.zw); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -73,10 +73,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y; const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); - buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw); - buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy); - buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy); + buf_a[buf_idx + 1 ] = FLOAT_TYPEV2(v0.zw); + buf_a[buf_idx + 8 ] = FLOAT_TYPEV2(v1.xy); + buf_a[buf_idx + 9 ] = FLOAT_TYPEV2(v1.zw); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -92,8 +92,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint vui = uint(data_a_packed16[ib].qs[iqs]); const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(v.yw); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -112,10 +112,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y; const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw); - buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xz); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v1.xz); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(v0.yw); + buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.yw); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -128,8 +128,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -147,8 +147,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); #elif defined(DATA_A_Q3_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -171,8 +171,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy); const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x), - dl * (qs.y - hm.y)); + buf_a[buf_idx] = FLOAT_TYPEV2(dl * (qs.x - hm.x), + dl * (qs.y - hm.y)); #elif defined(DATA_A_Q4_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -206,8 +206,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); + buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -244,8 +244,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4; const vec4 q = vec4(unpack8(qs | qh)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); + buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -267,7 +267,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint qh = (uint(data_a_packed16[ib].qh[qhi]) >> qhshift) & 0x0303; const vec2 q = (vec2(unpack8(ql | (qh << 4)).xy) - 32) * dscale; - buf_a[buf_idx] = FLOAT_TYPE_VEC2(q.x, q.y); + buf_a[buf_idx] = FLOAT_TYPEV2(q.x, q.y); #elif defined(DATA_A_IQ1_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -284,8 +284,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); [[unroll]] for (int k = 0; k < 4; ++k) { - buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), - dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); } #elif defined(DATA_A_IQ1_M) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; @@ -306,8 +306,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); [[unroll]] for (int k = 0; k < 4; ++k) { - buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), - dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); } #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; @@ -332,14 +332,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 grid0 = vec4(unpack8(grid.x)); const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, - (sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, - (sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, - (sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, - (sign & 128) != 0 ? -grid1.w : grid1.w); + buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_XS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -358,14 +358,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 grid0 = vec4(unpack8(grid.x)); const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, - (sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, - (sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, - (sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, - (sign & 128) != 0 ? -grid1.w : grid1.w); + buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -386,14 +386,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 grid0 = vec4(unpack8(grid.x)); const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, - (sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, - (sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, - (sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, - (sign & 128) != 0 ? -grid1.w : grid1.w); + buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ3_XXS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -414,10 +414,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint grid = iq3xxs_grid[qs]; const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, - (sign & 2) != 0 ? -v.y : v.y); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, - (sign & 8) != 0 ? -v.w : v.w); + buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ3_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -436,10 +436,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, - (sign & 2) != 0 ? -v.y : v.y); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, - (sign & 8) != 0 ? -v.w : v.w); + buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ4_XS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -456,8 +456,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = float(data_a[ib].d); const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -468,10 +468,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); const uint vui = uint(data_a_packed16[ib].qs[iqs]); - buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF], - kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); - buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], - kvalues_iq4nl[vui >> 12]); + buf_a[buf_idx ] = d * FLOAT_TYPEV2(kvalues_iq4nl[vui & 0xF], + kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); + buf_a[buf_idx + 8] = d * FLOAT_TYPEV2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], + kvalues_iq4nl[vui >> 12]); #elif defined(DATA_A_MXFP4) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -483,10 +483,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint vui = uint(data_a[ib].qs[iqs]); const uint vui2 = uint(data_a[ib].qs[iqs+1]); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d, - kvalues_mxfp4[vui2 & 0xF] * d); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d, - kvalues_mxfp4[vui2 >> 4] * d); + buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); #endif } @@ -496,7 +496,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin // Not supported for b_type bf16 because bf16mat2x4 does not exist const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); buf_b[buf_idx + 0] = bb[0].xy; buf_b[buf_idx + 1] = bb[0].zw; buf_b[buf_idx + 2] = bb[1].xy; @@ -505,9 +505,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); #endif buf_b[buf_idx + 0] = bb.xy; buf_b[buf_idx + 1] = bb.zw; @@ -515,12 +515,12 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const uint idx = pos_b + col * p.stride_b + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_n < p.N && block + row * 2 + 1 < end_k) { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); } else if (idx_n < p.N && block + row * 2 < end_k) { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); } else { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif } @@ -531,7 +531,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); buf_b[buf_idx + 0] = bb[0].xy; buf_b[buf_idx + 1] = bb[0].zw; buf_b[buf_idx + 2] = bb[1].xy; @@ -541,9 +541,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); #endif buf_b[buf_idx + 0] = bb.xy; buf_b[buf_idx + 1] = bb.zw; @@ -553,14 +553,14 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin if (row_i < _ne1 && block + row * 2 + 1 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); } else if (row_i < _ne1 && block + row * 2 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); } else { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index 9c297d1c60d..59931b04b94 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -21,7 +21,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; if (iqs == 0) { - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm); } #endif } @@ -72,7 +72,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; if (iqs == 0) { - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm); buf_a[buf_ib].qh = data_a_packed32[ib].qh; } #endif @@ -203,7 +203,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6); if (iqs == 0) { - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib_k].dm); buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to #12147 } } @@ -264,7 +264,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) | (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147 - buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32)); + buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales - 32)); } } @@ -334,7 +334,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); } - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm)); + buf_a[buf_ib].dm = FLOAT_TYPEV2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm)); } } @@ -385,7 +385,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint is = iqs_k / 4; const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy; - buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales)); + buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales)); } } @@ -426,7 +426,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo const uint ib_inner = ib % 4; if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + buf_b[buf_ib].ds = FLOAT_TYPEV2(data_b[ib_outer].ds[ib_inner]); } const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; @@ -436,7 +436,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; } else { if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f); + buf_b[buf_ib].ds = FLOAT_TYPEV2(0.0f); } buf_b[buf_ib].qs[iqs * 4 ] = 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl index 1c0f5306f38..c700f6e3f25 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -8,7 +8,7 @@ struct block_a_cache { #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[16/4]; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q5_0) #define QUANT_R_MMQ 2 @@ -22,7 +22,7 @@ struct block_a_cache { struct block_a_cache { uint32_t qs[16/4]; uint32_t qh; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q8_0) #define QUANT_R_MMQ 1 @@ -43,36 +43,36 @@ struct block_a_cache { struct block_a_cache { uint32_t qs[2]; u8vec2 scales; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q3_K) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[4]; - FLOAT_TYPE_VEC2 d_scales; + FLOAT_TYPEV2 d_scales; }; #elif defined(DATA_A_Q4_K) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[4]; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q5_K) #define QUANT_R_MMQ 1 struct block_a_cache { int32_t qs[8]; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q6_K) #define QUANT_R_MMQ 1 struct block_a_cache { int32_t qs[8]; - FLOAT_TYPE_VEC2 d_scales; + FLOAT_TYPEV2 d_scales; }; #endif struct block_b_cache { int32_t qs[8]; - FLOAT_TYPE_VEC2 ds; + FLOAT_TYPEV2 ds; }; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 7afdcef7d22..11385f93378 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -446,8 +446,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c base_dict["FLOAT16"] = "1"; } - base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float"; - base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2"; + base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float"; + base_dict["ACC_TYPEV2"] = f16acc ? "f16vec2" : "vec2"; if (f16acc) { base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; } @@ -514,10 +514,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; const std::map float_type_dict_f16 = { - {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")}, - {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")}, - {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")}, - {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")}, + {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")}, + {"FLOAT_TYPEV2", FLOAT_TYPE(2, "f16")}, + {"FLOAT_TYPEV4", FLOAT_TYPE(4, "f16")}, + {"FLOAT_TYPEV8", FLOAT_TYPE(8, "f16")}, }; // Shaders with f16 B_TYPE @@ -536,9 +536,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32"; const std::map float_type_dict_bf16 = { - {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")}, - {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")}, - {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")}, + {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")}, + {"FLOAT_TYPEV2", FLOAT_TYPE(2, "bf16")}, + {"FLOAT_TYPEV4", FLOAT_TYPE(4, "bf16")}, }; // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader @@ -569,10 +569,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; const std::map float_type_dict = { - {"FLOAT_TYPE", FLOAT_TYPE(1, tname)}, - {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)}, - {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)}, - {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)}, + {"FLOAT_TYPE", FLOAT_TYPE(1, tname)}, + {"FLOAT_TYPEV2", FLOAT_TYPE(2, tname)}, + {"FLOAT_TYPEV4", FLOAT_TYPE(4, tname)}, + {"FLOAT_TYPEV8", FLOAT_TYPE(8, tname)}, }; // don't generate f32 variants for coopmat2 @@ -676,36 +676,36 @@ void process_shaders() { } } - std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; + std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}}; for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; - string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); // mul mat vec with integer dot product #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") { - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); } #endif @@ -726,9 +726,9 @@ void process_shaders() { string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}}); - string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); - string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); - string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}); // Norms string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); From 4598eb080b513752b90663fe37b9c92ecbf9e5b1 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Thu, 9 Apr 2026 12:06:48 +0530 Subject: [PATCH 107/249] sycl : add flash-attn support for head size 512 (llama/21654) * sycl : add flash-attn support for head size 512 This patch extends the SYCL Flash Attention implementation to support head sizes (DKQ/DV) of 512. Changes: - Added DKQ/DV 512 cases to both tile and vector Flash Attention kernels. - Updated kernel selection logic to allow vector kernels for head sizes up to 512 (previously 256). - Removed unused/redundant AMD and RDNA-specific configuration functions in `fattn-tile.hpp`. - Refactored `ggml_backend_sycl_buffer_init_tensor` to use a switch statement for clearer tensor extra buffer initialization. - Added necessary template instances for the new 512 head size across various quantization types. * remove defunct mxfp4 reorder from setting buffer type --- ggml/src/ggml-sycl/fattn-tile.cpp | 4 + ggml/src/ggml-sycl/fattn-tile.hpp | 151 +++--------------- ggml/src/ggml-sycl/fattn-vec.hpp | 7 + ggml/src/ggml-sycl/fattn.cpp | 4 +- ggml/src/ggml-sycl/ggml-sycl.cpp | 21 ++- .../fattn-tile-instance-dkq512-dv512.cpp | 6 + .../fattn-vec-instance-f16-f16.cpp | 1 + .../fattn-vec-instance-f16-q4_0.cpp | 1 + .../fattn-vec-instance-f16-q4_1.cpp | 1 + .../fattn-vec-instance-f16-q5_0.cpp | 1 + .../fattn-vec-instance-f16-q5_1.cpp | 1 + .../fattn-vec-instance-f16-q8_0.cpp | 1 + .../fattn-vec-instance-q4_0-f16.cpp | 1 + .../fattn-vec-instance-q4_0-q4_0.cpp | 1 + .../fattn-vec-instance-q4_0-q4_1.cpp | 1 + .../fattn-vec-instance-q4_0-q5_0.cpp | 1 + .../fattn-vec-instance-q4_0-q5_1.cpp | 1 + .../fattn-vec-instance-q4_0-q8_0.cpp | 1 + .../fattn-vec-instance-q4_1-f16.cpp | 1 + .../fattn-vec-instance-q4_1-q4_0.cpp | 1 + .../fattn-vec-instance-q4_1-q4_1.cpp | 1 + .../fattn-vec-instance-q4_1-q5_0.cpp | 1 + .../fattn-vec-instance-q4_1-q5_1.cpp | 1 + .../fattn-vec-instance-q4_1-q8_0.cpp | 1 + .../fattn-vec-instance-q5_0-f16.cpp | 1 + .../fattn-vec-instance-q5_0-q4_0.cpp | 1 + .../fattn-vec-instance-q5_0-q4_1.cpp | 1 + .../fattn-vec-instance-q5_0-q5_0.cpp | 1 + .../fattn-vec-instance-q5_0-q5_1.cpp | 1 + .../fattn-vec-instance-q5_0-q8_0.cpp | 1 + .../fattn-vec-instance-q5_1-f16.cpp | 1 + .../fattn-vec-instance-q5_1-q4_0.cpp | 1 + .../fattn-vec-instance-q5_1-q4_1.cpp | 1 + .../fattn-vec-instance-q5_1-q5_0.cpp | 1 + .../fattn-vec-instance-q5_1-q5_1.cpp | 1 + .../fattn-vec-instance-q5_1-q8_0.cpp | 1 + .../fattn-vec-instance-q8_0-f16.cpp | 1 + .../fattn-vec-instance-q8_0-q4_0.cpp | 1 + .../fattn-vec-instance-q8_0-q4_1.cpp | 1 + .../fattn-vec-instance-q8_0-q5_0.cpp | 1 + .../fattn-vec-instance-q8_0-q5_1.cpp | 1 + .../fattn-vec-instance-q8_0-q8_0.cpp | 1 + 42 files changed, 95 insertions(+), 134 deletions(-) create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp diff --git a/ggml/src/ggml-sycl/fattn-tile.cpp b/ggml/src/ggml-sycl/fattn-tile.cpp index 9d4f019cf51..9449d75784d 100644 --- a/ggml/src/ggml-sycl/fattn-tile.cpp +++ b/ggml/src/ggml-sycl/fattn-tile.cpp @@ -44,6 +44,10 @@ void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; + case 512: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case<512, 512>(ctx, dst); + } break; case 576: { GGML_ASSERT(V->ne[0] == 512); ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst); diff --git a/ggml/src/ggml-sycl/fattn-tile.hpp b/ggml/src/ggml-sycl/fattn-tile.hpp index b4d4e0ae90e..9ba5296968d 100644 --- a/ggml/src/ggml-sycl/fattn-tile.hpp +++ b/ggml/src/ggml-sycl/fattn-tile.hpp @@ -67,6 +67,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) @@ -124,6 +130,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) @@ -131,134 +143,6 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co return 0; } -static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) { - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) - - return 0; -} - -static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) { - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) - - return 0; -} - static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) { if(fast_fp16_available(cc)) return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols); @@ -1293,6 +1177,16 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggm launch_fattn_tile_switch_ncols1(ctx, dst); return; } + // ncols2=2 and ncols2=1 fallbacks only for cases where ncols=2 config exists (DKQ == DV). + // For DKQ == 576, DV == 512 only GQA-optimized variants are implemented. + if constexpr (DKQ == DV) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } } if constexpr (DV <= 256) { @@ -1347,5 +1241,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(512, 512); extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-sycl/fattn-vec.hpp b/ggml/src/ggml-sycl/fattn-vec.hpp index 48c389052f4..8031acfdff8 100644 --- a/ggml/src/ggml-sycl/fattn-vec.hpp +++ b/ggml/src/ggml-sycl/fattn-vec.hpp @@ -664,4 +664,11 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q8_0) + #endif // GGML_SYCL_FATTN_VEC_HPP diff --git a/ggml/src/ggml-sycl/fattn.cpp b/ggml/src/ggml-sycl/fattn.cpp index c276ed89827..7c6e6112fdc 100644 --- a/ggml/src/ggml-sycl/fattn.cpp +++ b/ggml/src/ggml-sycl/fattn.cpp @@ -34,6 +34,7 @@ FATTN_VEC_CASE( 64, type_K, type_V) \ FATTN_VEC_CASE(128, type_K, type_V) \ FATTN_VEC_CASE(256, type_K, type_V) \ + FATTN_VEC_CASE(512, type_K, type_V) \ static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_tensor * Q = dst->src[0]; @@ -141,6 +142,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const case 128: case 112: case 256: + case 512: if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; } @@ -185,7 +187,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const } // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + const bool can_use_vector_kernel = Q->ne[0] <= 512 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; // Todo: Use the XMX kernel if possible: diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e80ead9aea4..7f9b2df524e 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -411,11 +411,22 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, assert(tensor->view_src->buffer->buft == buffer->buft); return GGML_STATUS_SUCCESS; } - if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) && - !g_ggml_sycl_disable_optimize) { - ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; - tensor->extra = extra; - ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. + + if (!g_ggml_sycl_disable_optimize) { + // set reorder extra buffer based on supported type + switch (tensor->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K:{ + ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; + tensor->extra = extra; + ctx->tensor_extras.push_back(extra); + break; + } + default: + break; + } } if (ggml_is_quantized(tensor->type)) { diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp new file mode 100644 index 00000000000..9a6a1877566 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp @@ -0,0 +1,6 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(512, 512); + diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp index 32cf4f2859b..43ef94c118c 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp index a61a19021bb..9404061d456 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp index 63b74fb347a..a8bb9f52d0c 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp index 46e2d9853c5..7d61f6ab0af 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp index 7aabb6ff6e4..753bae09f83 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp index 148ea217f62..546a93b2570 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp index 4b169dbcdbc..53c8c2f2654 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp index 79f530b1815..5b409c55f21 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp index 2f7db51ce82..8c4ef588d63 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp index 9e3bf0b14a1..83f0a07552e 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp index 18081879cec..9df9b03bba4 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp index 1c387b0d87c..6980c2a65bb 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp index f005b3762cc..bd61bc1dc2b 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp index 3553b1cdd16..492e229a58e 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp index 687ec567115..30f88a2ebd5 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp index 2663bfe7466..db76663604e 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp index 641b7c7ae2a..1dbcc8a85a8 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp index 3d3181d4719..d30996a6259 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp index 85d5026ad4f..bc0f635d922 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp index 1e81401a2c9..9e0378107cb 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp index 54251473f97..a8535ac9156 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp index d418c1fb21e..43d4fae9a61 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp index 0f26cfabd09..23335a41640 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp index 4fb98723519..52550a33757 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp index 85b79cd1976..4651f14c050 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp index 7348323b28b..2310fd8792c 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp index f19af2aa0ba..d2494048bc1 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp index d7075bac600..be3a1fe97f5 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp index 627f9a57755..be0a89409ca 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp index 23304eecd35..6781efcb0d2 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp index 95acb5d4fbf..43a70ae3543 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp index 5e88f4bab8a..fa7eb8163ca 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp index 69f297feb0c..79d9cfbee96 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp index 455842a9421..86befd5d327 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp index f7ef7391571..c2f619b0b16 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp index 1c633bdf2fa..7cf31f8b8a1 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); From f0ee409f7b3c1ae1e0b3c1139ef8e7e02b0bb3b3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 9 Apr 2026 10:54:00 +0300 Subject: [PATCH 108/249] metal : add missing mm-id specializations for q1_0 (llama/21662) --- ggml/src/ggml-metal/ggml-metal.metal | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f28bfa0b95b..f67c5cd8a1d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -10079,6 +10079,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id; #endif +template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; @@ -10102,6 +10103,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; From c4c6e143a7731627be5f6d72c4738ac3ca066bd6 Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:17:11 +0200 Subject: [PATCH 109/249] ggml : check return value of CUB calls used in argsort and top-k (they all return cudaError_t) (llama/21676) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Stanisław Szymczyk --- ggml/src/ggml-cuda/argsort.cu | 32 ++++++++++++++++---------------- ggml/src/ggml-cuda/top-k.cu | 8 ++++---- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 38fdf3678c1..ed4e5de70f5 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -60,24 +60,24 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { - DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + ncols, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices) ncols * nrows, nrows, // num items, num segments - offset_iterator, offset_iterator + 1, stream); + offset_iterator, offset_iterator + 1, stream)); } } else { if (nrows == 1) { - DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + ncols, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1, - stream); + stream)); } } @@ -86,22 +86,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { - DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + ncols, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream); + CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, + ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream)); } } else { if (nrows == 1) { - DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + ncols, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, offset_iterator, - offset_iterator + 1, stream); + offset_iterator + 1, stream)); } } } diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 785a18389f2..59ce36fb1c9 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -25,14 +25,14 @@ static void top_k_cub(ggml_cuda_pool & pool, auto indexes_in = cuda::make_counting_iterator(0); size_t temp_storage_bytes = 0; - DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, - env); + CUDA_CHECK(DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, + env)); ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); void * d_temp_storage = temp_storage_alloc.get(); - DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, - ncols, k, env); + CUDA_CHECK(DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, + ncols, k, env)); } #elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE From bb895c843d249ee4a15dcfa19caf2d78ad5e2aa0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 9 Apr 2026 16:42:19 +0200 Subject: [PATCH 110/249] ggml: backend-agnostic tensor parallelism (experimental) (llama/19378) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml: backend-agnostic tensor parallelism * support for GPT-OSS, Qwen 3 MoE * partial Vulkan fix * add support for 4/8 GPUs * unconditional peer access * re-use buffers + ggml contexts * fix output pattern * NCCL support * GGML: HIP: add RCCL support * Remove shfl and AllReduce from backend interface * move allocation workaround out of ggml-alloc.c * 2d tensor set/get support * Fix the seg fault without NCCL * Apply suggestion from JohannesGaessler * support for tensor dims % n_devs != 0 * fix view_offs scaling * arbitrary num. of GPUs/tensor split * fix compilation * better granularity estimate * Support device-specific host buffer types if all underlying backends expose the same type. This allows using pinned memory instead of pageable memory for CUDA. Fix compilation errors. * partial Qwen 3 Next support * Fix qwen3 30b (llama/8) * Fix crash with Qwen-30B-A3B Q4_0 Qwen-30B-A3B Q4_0 has an intermediate dimension of 768. Using a granularity of 256 forces an uneven split between GPUs, which is not supported by the current implementation. * Decide block size based on tensor quantization type * Fix crashes due to KV cache serialization (llama/9) KV cache serialization requires non-zero offsets on the tensor. Add support in the meta backend to set/get a tensor with a non-zero offset. * metal : fix build (llama/7) * static memory allocations, fix usage count * fix tensor granularity * more even memory distribution * use BF16 for allreduce * rebase fixup * better error message for unsupported architectures * Fix device mismatch during scatter of allReduce. (llama/11) There is a mismatch between the dst buffer device and the backend device, causing the use of sync copies * Enable the previous allreduce implementation. It is better in both perf and stability (llama/12) * delay AllReduce for Moe for less I/O * build : clean-up compile warnings * backend : move most of the meta backend API to ggml-backend-impl.h * cont : hide unused public API in the implementation * llama : use llama_device + remove ggml_backend_dev_is_meta() * ggml-backend : remove unused alloc include * minor : remove regex include * ggml : introduce ggml-ext.h for staging new APIs * rebase fixup * fix tests * llama : more robust logic for determining Meta devices (llama/16) * llama : more robust logic for determining Meta devices * cont : fix devs size check Co-authored-by: Johannes Gäßler * cont : fix log type Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler * disable roundtrip for meta backend * fix arch selection * Qwen 3.5 support * fix Gemma 4 MoE * fix OpenVino, SYCL * fix test-llama-archs for CPU-only builds * Fix Qwen 3.5 MoE * disable meta backend tests for WebGPU * tests : filter CPU-based devices from the Meta backend tests (llama/17) * meta : formatting, naming, indentation (llama/18) * formatting : llama-model.cpp * formatting : ggml-ext.h * formatting : ggml-backend-meta.cpp * meta : add TODO * add documentation * better error messages * fix GPT-OSS --------- Co-authored-by: Carl Philipp Klemm Co-authored-by: Gaurav Garg Co-authored-by: Georgi Gerganov --- ggml/CMakeLists.txt | 4 + ggml/include/ggml-backend.h | 26 +- ggml/include/ggml-cuda.h | 3 + ggml/src/CMakeLists.txt | 1 + ggml/src/ggml-alloc.c | 3 + ggml/src/ggml-backend-impl.h | 24 +- ggml/src/ggml-backend-meta.cpp | 1923 +++++++++++++++++ ggml/src/ggml-backend.cpp | 110 +- ggml/src/ggml-blas/ggml-blas.cpp | 2 + ggml/src/ggml-cann/ggml-cann.cpp | 4 + ggml/src/ggml-cpu/amx/amx.cpp | 2 + ggml/src/ggml-cpu/ggml-cpu.cpp | 2 + ggml/src/ggml-cuda/CMakeLists.txt | 10 + ggml/src/ggml-cuda/common.cuh | 8 + ggml/src/ggml-cuda/ggml-cuda.cu | 245 ++- ggml/src/ggml-cuda/vendors/cuda.h | 4 + ggml/src/ggml-cuda/vendors/hip.h | 6 + ggml/src/ggml-hexagon/ggml-hexagon.cpp | 4 + ggml/src/ggml-hip/CMakeLists.txt | 12 + ggml/src/ggml-metal/ggml-metal.cpp | 24 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 4 + ggml/src/ggml-openvino/ggml-openvino.cpp | 4 + ggml/src/ggml-rpc/ggml-rpc.cpp | 4 + ggml/src/ggml-sycl/ggml-sycl.cpp | 6 + ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp | 4 + ggml/src/ggml-virtgpu/ggml-backend.cpp | 2 + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 + ggml/src/ggml-zdnn/ggml-zdnn.cpp | 32 +- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 2 + 30 files changed, 2362 insertions(+), 121 deletions(-) create mode 100644 ggml/src/ggml-backend-meta.cpp diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 5834e544b48..6bf15723b3c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -7,6 +7,8 @@ set(GGML_VERSION_MINOR 9) set(GGML_VERSION_PATCH 11) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") + find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) if(GIT_EXE) # Get current git commit hash @@ -204,12 +206,14 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON) option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF) option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT}) +option(GGML_CUDA_NCCL "ggml: use NVIDIA Collective Comm. Library" ON) set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING "ggml: cuda link binary compression mode; requires cuda 12.8+") set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size") option(GGML_HIP "ggml: use HIP" OFF) option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_RCCL "ggml: use ROCm Collective Comm. Library" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 9fd3f7f32a0..3c06aeaffb1 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -68,7 +68,7 @@ extern "C" { GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer); // tensor copy between different backends - GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); + GGML_API void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst); // // Backend (stream) @@ -83,13 +83,17 @@ extern "C" { GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend); - GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_async (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get_async (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); // "offset" refers to the offset in tensor->data for setting/getting data - GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set ( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get (const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_2d( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); GGML_API void ggml_backend_synchronize(ggml_backend_t backend); @@ -109,7 +113,7 @@ extern "C" { // the copy is performed after all the currently queued operations in backend_src // backend_dst will wait for the copy to complete before performing other operations // automatic fallback to sync copy if async is not supported - GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst); + GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend); @@ -135,7 +139,9 @@ extern "C" { // integrated GPU device using host memory GGML_BACKEND_DEVICE_TYPE_IGPU, // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX) - GGML_BACKEND_DEVICE_TYPE_ACCEL + GGML_BACKEND_DEVICE_TYPE_ACCEL, + // "meta" device wrapping multiple other devices for tensor parallelism + GGML_BACKEND_DEVICE_TYPE_META, }; // functionality supported by the device @@ -196,7 +202,9 @@ extern "C" { // Common functions that may be obtained using ggml_backend_reg_get_proc_address - // Split buffer type for tensor parallelism + // AllReduce operation for tensor parallelism (meta backend) + typedef bool (*ggml_backend_allreduce_tensor_t)(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends); + // Split buffer type for tensor parallelism (old) typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split); // Set the number of threads for the backend typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads); diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h index 22ad2c00963..5436c7ef579 100644 --- a/ggml/include/ggml-cuda.h +++ b/ggml/include/ggml-cuda.h @@ -27,6 +27,9 @@ GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend); // device buffer GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device); +// conduct allreduce operation between devices +GGML_BACKEND_API bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends); + // split tensor buffer that splits matrices by rows across multiple devices GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split); diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 78853304d9f..48fbe208d90 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -200,6 +200,7 @@ add_library(ggml-base ggml.cpp ggml-alloc.c ggml-backend.cpp + ggml-backend-meta.cpp ggml-opt.cpp ggml-threading.cpp ggml-threading.h diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 7f414b2311c..e9b70398ffc 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -1236,6 +1236,9 @@ size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { size_t nbytes_total = 0; + if (ggml_backend_buft_is_meta(buft)) { + return ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx, buft); + } return ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc =*/ false); } diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 59190b7c465..9c56ec30c5f 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -49,6 +49,10 @@ extern "C" { void (*memset_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + // (optional) 2d data copies + void (*set_tensor_2d)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + void (*get_tensor_2d)(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported) bool (*cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // clear the entire buffer @@ -80,6 +84,20 @@ extern "C" { GGML_API bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer); GGML_API void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + // + // Backend (meta) + // + + GGML_API bool ggml_backend_is_meta (ggml_backend_t backend); + GGML_API bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf); + GGML_API bool ggml_backend_buft_is_meta (ggml_backend_buffer_type_t buft); + + GGML_API size_t ggml_backend_meta_n_backends (ggml_backend_t meta_backend); + GGML_API ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index); + + // temporary workaround to statically allocate tensors from a context in a deduplicated way: + GGML_API struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); + // // Backend (stream) // @@ -90,8 +108,10 @@ extern "C" { void (*free)(ggml_backend_t backend); // (optional) asynchronous tensor data access - void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*set_tensor_async) (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async) (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*set_tensor_2d_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + void (*get_tensor_2d_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); // (optional) complete all pending operations (required if the backend supports async operations) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp new file mode 100644 index 00000000000..a2ab8872c4a --- /dev/null +++ b/ggml/src/ggml-backend-meta.cpp @@ -0,0 +1,1923 @@ +#include "ggml.h" +#include "ggml-impl.h" +#include "ggml-backend.h" +#include "ggml-backend-impl.h" +#include "ggml-alloc.h" +#include "ggml-cpp.h" + +// TODO: tmp +#include "ggml-ext.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct ggml_backend_meta_device; +struct ggml_backend_meta_buffer_type; +struct ggml_backend_meta_buffer; +struct ggml_backend_meta; + +const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis) { + switch (split_axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + return "0"; + case GGML_BACKEND_SPLIT_AXIS_1: + return "1"; + case GGML_BACKEND_SPLIT_AXIS_2: + return "2"; + case GGML_BACKEND_SPLIT_AXIS_3: + return "3"; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + return "MIRRORED"; + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: + return "PARTIAL"; + case GGML_BACKEND_SPLIT_AXIS_NONE: + return "NONE"; + case GGML_BACKEND_SPLIT_AXIS_UNKNOWN: + return "UNKNOWN"; + default: + GGML_ABORT("fatal error"); + } +} + +// +// meta backend device +// + +struct ggml_backend_meta_device_context { + std::vector simple_devs; + ggml_backend_meta_get_split_state_t get_split_state; + void * get_split_state_ud; + + std::string name; + std::string description; + + ggml_backend_meta_device_context( + std::vector simple_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) : + simple_devs(std::move(simple_devs)), get_split_state(get_split_state), get_split_state_ud(get_split_state_ud) { + name = std::string("Meta("); + description = std::string("Meta("); + for (size_t i = 0; i < simple_devs.size(); i++) { + if (i > 0) { + name += ","; + description += ","; + } + name += ggml_backend_dev_name (simple_devs[i]); + description += ggml_backend_dev_description(simple_devs[i]); + } + name += ")"; + description += ")"; + } + + bool operator<(const ggml_backend_meta_device_context & other) const { + return std::tie(simple_devs, get_split_state, get_split_state_ud) + < std::tie(other.simple_devs, other.get_split_state, other.get_split_state_ud); + } +}; + +static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev); + +static const char * ggml_backend_meta_device_get_name(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return meta_dev_ctx->name.c_str(); +} + +static const char * ggml_backend_meta_device_get_description(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return meta_dev_ctx->description.c_str(); +} + +static void ggml_backend_meta_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + *free = 0; + *total = 0; + for (ggml_backend_dev_t dev : meta_dev_ctx->simple_devs) { + size_t tmp_free, tmp_total; + ggml_backend_dev_memory(dev, &tmp_free, &tmp_total); + *free += tmp_free; + *total += tmp_total; + } +} + +static enum ggml_backend_dev_type ggml_backend_meta_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_META; + + GGML_UNUSED(dev); +} + +static void ggml_backend_meta_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + + // TODO replace placeholders + props->name = ggml_backend_meta_device_get_name(dev); + props->description = ggml_backend_meta_device_get_description(dev); + props->type = ggml_backend_meta_device_get_type(dev); + props->device_id = 0; + + ggml_backend_meta_device_get_memory(dev, &props->memory_free, &props->memory_total); + + props->caps = { + /* .async = */ true, + /* .host_buffer = */ false, // Not implemented. + /* .buffer_from_host_ptr = */ false, // Not implemented. + /* .events = */ false, // Not implemented. + }; + for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) { + ggml_backend_dev_props tmp_props; + ggml_backend_dev_get_props(simple_dev, &tmp_props); + props->caps.async = props->caps.async && tmp_props.caps.async; + props->caps.host_buffer = props->caps.host_buffer && tmp_props.caps.host_buffer; + props->caps.buffer_from_host_ptr = props->caps.buffer_from_host_ptr && tmp_props.caps.buffer_from_host_ptr; + props->caps.events = props->caps.events && tmp_props.caps.events; + } +} + +static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params); + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev); + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev); + +static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return std::all_of(meta_dev_ctx->simple_devs.begin(), meta_dev_ctx->simple_devs.end(), + [op](ggml_backend_dev_t simple_dev) { return ggml_backend_dev_supports_op(simple_dev, op); }); +} + +static bool ggml_backend_meta_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + ggml_backend_dev_t dev_buft = ggml_backend_buft_get_device(buft); + if (!ggml_backend_dev_is_meta(dev_buft)) { + return false; + } + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + const ggml_backend_meta_device_context * meta_buft_dev_ctx = (const ggml_backend_meta_device_context *) dev_buft->context; + if (meta_dev_ctx->simple_devs.size() != meta_buft_dev_ctx->simple_devs.size()) { + return false; + } + for (size_t i = 0; i < meta_dev_ctx->simple_devs.size(); i++) { + if (meta_dev_ctx->simple_devs[i] != meta_buft_dev_ctx->simple_devs[i]) { + return false; + } + } + return true; +} + +static const ggml_backend_device_i ggml_backend_meta_device_iface = { + /* .get_name = */ ggml_backend_meta_device_get_name, + /* .get_description = */ ggml_backend_meta_device_get_description, + /* .get_memory = */ ggml_backend_meta_device_get_memory, + /* .get_type = */ ggml_backend_meta_device_get_type, + /* .get_props = */ ggml_backend_meta_device_get_props, + /* .init_backend = */ ggml_backend_meta_device_init_backend, + /* .get_buffer_type = */ ggml_backend_meta_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ nullptr, + /* .supports_op = */ ggml_backend_meta_device_supports_op, + /* .supports_buft = */ ggml_backend_meta_device_supports_buft, + /* .offload_op = */ nullptr, + /* .event_new = */ nullptr, + /* .event_free = */ nullptr, + /* .event_synchronize = */ nullptr, +}; + +static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev) { + return dev != nullptr && dev->iface.get_name == ggml_backend_meta_device_iface.get_name; +} + +static size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context; + return meta_dev_ctx->simple_devs.size(); +} + +static ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index) { + GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context; + GGML_ASSERT(index < meta_dev_ctx->simple_devs.size()); + return meta_dev_ctx->simple_devs[index]; +} + +ggml_backend_dev_t ggml_backend_meta_device( + ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) { + GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES); + // TODO: this is not thread-safe - needs to be fixed + static std::vector> ctxs; + static std::map meta_devs; + + std::vector simple_devs; + simple_devs.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + simple_devs.push_back(devs[i]); + } + ggml_backend_meta_device_context ctx(simple_devs, get_split_state, get_split_state_ud); + + { + auto it = meta_devs.find(ctx); + if (it != meta_devs.end()) { + return &it->second; + } + } + ctxs.push_back(std::make_unique(ctx)); + + struct ggml_backend_device meta_dev = { + /*iface =*/ ggml_backend_meta_device_iface, + /*reg =*/ nullptr, + /*ctx =*/ ctxs.back().get(), + }; + + auto result = meta_devs.emplace(*ctxs.back(), meta_dev); + return &result.first->second; +} + +// +// meta backend buffer type +// + +struct ggml_backend_meta_buffer_type_context { + std::vector simple_bufts; + + std::string name; + + ggml_backend_meta_buffer_type_context(std::vector simple_bufts) : simple_bufts(std::move(simple_bufts)) { + name = "Meta("; + for (size_t i = 0; i < simple_bufts.size(); i++) { + if (i > 0) { + name += ","; + } + name += ggml_backend_buft_name(simple_bufts[i]); + } + name += ")"; + } + + bool operator<(const ggml_backend_meta_buffer_type_context & other) const { + return simple_bufts < other.simple_bufts; + } +}; + +static size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) { + GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context; + return meta_buft_ctx->simple_bufts.size(); +} + +static const char * ggml_backend_meta_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(ggml_backend_buft_is_meta(buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) buft->context; + return meta_buft_ctx->name.c_str(); +} + +static ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index) { + GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context; + GGML_ASSERT(index < meta_buft_ctx->simple_bufts.size()); + return meta_buft_ctx->simple_bufts[index]; +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); + +static size_t ggml_backend_meta_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_alignment = 1; + for (size_t i = 0; i < n_simple_bufts; i++) { + const size_t alignment = ggml_backend_buft_get_alignment(ggml_backend_meta_buft_simple_buft(buft, i)); + max_alignment = std::max(max_alignment, alignment); + GGML_ASSERT(max_alignment % alignment == 0); + } + return max_alignment; +} + +static size_t ggml_backend_meta_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_size = SIZE_MAX; + for (size_t i = 0; i < n_simple_bufts; i++) { + max_size = std::min(max_size, ggml_backend_buft_get_max_size(ggml_backend_meta_buft_simple_buft(buft, i))); + } + return max_size; +} + +static size_t ggml_backend_meta_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_alloc_size = 0; + for (size_t i = 0; i < n_simple_bufts; i++) { + const size_t alloc_size = ggml_backend_buft_get_alloc_size(ggml_backend_meta_buft_simple_buft(buft, i), tensor); + max_alloc_size = std::max(max_alloc_size, alloc_size); + } + return max_alloc_size; +} + +static bool ggml_backend_meta_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + for (size_t i = 0; i < n_simple_bufts; i++) { + if (!ggml_backend_buft_is_host(ggml_backend_meta_buft_simple_buft(buft, i))) { + return false; + } + } + return true; +} + +static const struct ggml_backend_buffer_type_i ggml_backend_meta_buffer_type_iface = { + /* .get_name = */ ggml_backend_meta_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_meta_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_meta_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_meta_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_meta_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_meta_buffer_type_is_host, +}; + +bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft) { + return buft != nullptr && buft->iface.get_name == ggml_backend_meta_buffer_type_iface.get_name; +} + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev) { + static std::map meta_bufts; + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + { + auto it = meta_bufts.find(dev); + if (it != meta_bufts.end()) { + return &it->second; + } + } + + const size_t n_devs = ggml_backend_meta_dev_n_devs(dev); + std::vector simple_bufts; + simple_bufts.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + simple_bufts.push_back(ggml_backend_dev_buffer_type(ggml_backend_meta_dev_simple_dev(dev, i))); + } + ggml_backend_meta_buffer_type_context * buft_ctx = new ggml_backend_meta_buffer_type_context(simple_bufts); + + struct ggml_backend_buffer_type meta_buft = { + /*iface =*/ ggml_backend_meta_buffer_type_iface, + /*device =*/ dev, + /*ctx =*/ buft_ctx, + }; + auto result = meta_bufts.emplace(dev, meta_buft); + return &result.first->second; +} + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + + ggml_backend_buffer_type_t host_buft = nullptr; + for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) { + ggml_backend_buffer_type_t simple_host_buft = ggml_backend_dev_host_buffer_type(simple_dev); + if (simple_host_buft == nullptr) { + return nullptr; + } + if (host_buft == nullptr) { + host_buft = simple_host_buft; + } else if (host_buft != simple_host_buft) { + // if different simple devices have different host buffer types, + // we cannot provide a single host buffer type for the meta device + return nullptr; + } + } + return host_buft; +} + +// +// meta backend buffer +// + +struct ggml_backend_meta_buffer_context { + static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding); + + std::map, std::pair> split_state_cache; + std::map< const ggml_tensor *, std::vector> simple_tensors; + + struct buffer_config { + ggml_context * ctx; + ggml_backend_buffer_t buf; + + buffer_config(ggml_context * ctx, ggml_backend_buffer_t buf) : ctx(ctx), buf(buf) {} + }; + std::vector buf_configs; + + int debug; + + ggml_backend_meta_buffer_context() { + const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG"); + debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0; + } +}; + +static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + for (auto & [ctx, buf] : buf_ctx->buf_configs) { + ggml_backend_buffer_free(buf); + ggml_free(ctx); + } + delete buf_ctx; +} + +static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) { + GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; + return buf_ctx->buf_configs.size(); +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) { + GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; + GGML_ASSERT(index < buf_ctx->buf_configs.size()); + return buf_ctx->buf_configs[index].buf; +} + +static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) { + GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + GGML_ASSERT(index < buf_ctx->buf_configs.size()); + + auto it = buf_ctx->simple_tensors.find(tensor); + if (it == buf_ctx->simple_tensors.end()) { + return nullptr; + } + return it->second[index]; +} + +static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + + auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool { + if (a.axis != b.axis) { + return false; + } + for (size_t j = 0; j < n_bufs; j++) { + int64_t sum_a = 0; + for (size_t s = 0; s < a.n_segments; s++) { + sum_a += a.ne[s*n_bufs + j]; + } + int64_t sum_b = 0; + for (size_t s = 0; s < b.n_segments; s++) { + sum_b += b.ne[s*n_bufs + j]; + } + if (sum_a != sum_b) { + return false; + } + } + return true; + }; + + auto handle_generic = [&](const std::vector & src_ss, bool scalar_only) -> ggml_backend_meta_split_state { + ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}; + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { + continue; + } + if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { + ret = src_ss[i]; + } else if (!split_states_equal(src_ss[i], ret)) { + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + break; + } + } + if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + return ret; + }; + + // Some ops process data on a per-row bases: + auto handle_per_row = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_0); + return src_ss[0]; + }; + + // Some ops broadcast the src1 data across src0: + auto handle_bin_bcast = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS && + tensor->src[1]->ne[src_ss[0].axis] == 1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[0]; + } + if (src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[0].axis == src_ss[1].axis || + (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL)))) { + return src_ss[0]; // GGML_OP_ADD_ID + } + GGML_ASSERT(tensor->src[2] == nullptr || src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + + auto handle_concat = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + const ggml_backend_meta_split_axis concat_axis = ggml_backend_meta_split_axis(ggml_get_op_params_i32(tensor, 0)); + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis >= 0 && src_ss[1].axis < GGML_MAX_DIMS) { + GGML_ASSERT(concat_axis != src_ss[1].axis); + return src_ss[1]; + } + if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + GGML_ASSERT(concat_axis != src_ss[0].axis); + return src_ss[0]; + } + if (src_ss[0].axis == src_ss[1].axis && src_ss[0].axis != concat_axis) { + return src_ss[0]; + } + return handle_generic(src_ss, /*scalar_only =*/ true); + }; + + auto handle_mul_mat = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + ggml_backend_meta_split_state ret = src_ss[0]; + ret.axis = GGML_BACKEND_SPLIT_AXIS_0; + ret.n_segments = 1; + return ret; + } + if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + ggml_backend_meta_split_state ret = src_ss[1]; + ret.n_segments = 1; + return ret; + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1])); + return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, 1}; + } + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + }; + + auto handle_cpy = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + int64_t ne_split_src = tensor->src[0]->ne[0]; + for (int dim = 1; dim <= src_ss[0].axis; dim++) { + ne_split_src *= tensor->src[0]->ne[dim]; + } + int64_t ne_split_dst = 1; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + ne_split_dst *= tensor->ne[dim]; + if (ne_split_dst == ne_split_src) { + return {ggml_backend_meta_split_axis(dim), {0}, 1}; + } + } + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + + auto handle_reshape = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + switch (src_ss[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: { + GGML_ASSERT(!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0])); + if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1) { + return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, 1}; + } + std::vector base_ne_in; + base_ne_in.reserve(GGML_MAX_DIMS - src_ss[0].axis); + { + base_ne_in.push_back(1); + int dim = 0; + for (; dim <= src_ss[0].axis; dim++) { + base_ne_in[0] *= tensor->src[0]->ne[dim]; + } + for (; dim <= GGML_MAX_DIMS; dim++) { + base_ne_in.push_back(base_ne_in.back() * tensor->src[0]->ne[dim]); + } + } + int64_t base_ne_out = 1; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim]; + for (const int64_t & bni : base_ne_in) { + if (bni == base_ne_out_next) { + return {ggml_backend_meta_split_axis(dim), {0}, 1}; + } + } + if (base_ne_out_next > base_ne_in[0]) { + GGML_ASSERT(dim + 1 < GGML_MAX_DIMS); + return {ggml_backend_meta_split_axis(dim + 1), {0}, 1}; + } + base_ne_out = base_ne_out_next; + } + GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op)); + } + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_ss[0]; + } + default: { + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + } + }; + + auto handle_view = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) { + return handle_reshape(src_ss); + } + const int axis = src_ss[0].axis; + { + bool all_strides_the_same = true; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + if (tensor->ne[dim] == 1 && tensor->src[0]->ne[dim] == 1) { + continue; + } + if (tensor->nb[dim] != tensor->src[0]->nb[dim]) { + all_strides_the_same = false; + break; + } + } + if (all_strides_the_same) { + return src_ss[0]; + } + } + if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) { + for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) { + if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) { + return {ggml_backend_meta_split_axis(dim), {0}, 1}; + } + } + GGML_ABORT("fatal error"); + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + return src_ss[0]; + } + GGML_ABORT("view of permuted tensor not implemented"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + }; + + auto handle_permute = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + switch (src_ss[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: { + return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, 1}; + } + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_ss[0]; + } + default: { + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + } + }; + + auto handle_transpose = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + switch (src_ss[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: { + return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, 1}; + } + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_ss[0]; + } + default: { + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + } + }; + + auto handle_get_rows = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[0]; + } + return handle_generic(src_ss, /*scalar_only =*/ true); + }; + + auto handle_set_rows = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + GGML_ASSERT(split_states_equal(src_ss[0], src_ss[2])); + return src_ss[0]; + }; + + auto handle_rope = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + return src_ss[0]; + }; + + auto handle_pad = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 0] == 0); + GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 1] == 0); + } + return src_ss[0]; + }; + + auto handle_flash_attn_ext = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT( src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT( src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT( src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0); + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + }; + + auto handle_ssm_conv = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == src_ss[1].axis) { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) { + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) { + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + } + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + + auto handle_gated_delta_net = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[0]; + } + GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2); + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + }; + + auto calculate_split_state = [&]() -> ggml_backend_meta_split_state { + if (ggml_nelements(tensor) == 0) { + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) { + ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer)); + const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud); + if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) { + const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1; + int64_t ne_sum = 0; + for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { + GGML_ASSERT(ret.ne[sj] % granularity == 0); + ne_sum += ret.ne[sj]; + } + GGML_ASSERT(ne_sum == tensor->ne[ret.axis]); + } + return ret; + } + + std::vector src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}); + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { + src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + continue; + } + src_ss[i] = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); + GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + } + + ggml_backend_meta_split_state split_state; + switch (tensor->op) { + case GGML_OP_NONE: { + split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + } break; + case GGML_OP_DUP: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_ADD: + case GGML_OP_ADD_ID: { + split_state = handle_bin_bcast(src_ss); + } break; + case GGML_OP_ADD1: + case GGML_OP_ACC: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: { + split_state = handle_bin_bcast(src_ss); + } break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_SIN: + case GGML_OP_COS: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_SUM: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SUM_ROWS: + case GGML_OP_CUMSUM: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_CONCAT: { + split_state = handle_concat(src_ss); + } break; + case GGML_OP_SILU_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_GROUP_NORM: + case GGML_OP_L2_NORM: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: { + split_state = handle_mul_mat(src_ss); + } break; + case GGML_OP_OUT_PROD: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SCALE: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_SET: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_CPY: { + split_state = handle_cpy(src_ss); + } break; + case GGML_OP_CONT: + case GGML_OP_RESHAPE: { + split_state = handle_reshape(src_ss); + } break; + case GGML_OP_VIEW: { + split_state = handle_view(src_ss); + } break; + case GGML_OP_PERMUTE: { + split_state = handle_permute(src_ss); + } break; + case GGML_OP_TRANSPOSE: { + split_state = handle_transpose(src_ss); + } break; + case GGML_OP_GET_ROWS: { + split_state = handle_get_rows(src_ss); + } break; + case GGML_OP_GET_ROWS_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SET_ROWS: { + split_state = handle_set_rows(src_ss); + } break; + case GGML_OP_DIAG: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_DIAG_MASK_ZERO: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_ROPE: { + split_state = handle_rope(src_ss); + } break; + case GGML_OP_ROPE_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_CLAMP: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_BACK: + case GGML_OP_IM2COL_3D: + case GGML_OP_CONV_2D: + case GGML_OP_CONV_3D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_CONV_TRANSPOSE_2D: + case GGML_OP_POOL_1D: + case GGML_OP_POOL_2D: + case GGML_OP_POOL_2D_BACK: + case GGML_OP_UPSCALE: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_PAD: { + split_state = handle_pad(src_ss); + } break; + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_ROLL: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_LEAKY_RELU: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_TRI: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_FILL: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_FLASH_ATTN_EXT: { + split_state = handle_flash_attn_ext(src_ss); + } break; + case GGML_OP_FLASH_ATTN_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SSM_CONV: { + split_state = handle_ssm_conv(src_ss); + } break; + case GGML_OP_SSM_SCAN: + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_GET_REL_POS: + case GGML_OP_ADD_REL_POS: + case GGML_OP_RWKV_WKV6: + case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_RWKV_WKV7: + case GGML_OP_SOLVE_TRI: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_GATED_DELTA_NET: { + split_state = handle_gated_delta_net(src_ss); + } break; + case GGML_OP_UNARY: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_MAP_CUSTOM1: + case GGML_OP_MAP_CUSTOM2: + case GGML_OP_MAP_CUSTOM3: + case GGML_OP_CUSTOM: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + case GGML_OP_GLU: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + default: { + GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op)); + split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } break; + } + if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { + bool first_src_split_by_axis = true; + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); + + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || src_ss[i].axis < 0 || src_ss[i].axis >= GGML_MAX_DIMS) { + continue; + } + if (first_src_split_by_axis) { + for (size_t j = 0; j < n_bufs; j++) { + // Take over ratio from src: + for (size_t s = 0; s < src_ss[i].n_segments; s++) { + split_state.ne[s*n_bufs + j] = 0; + } + for (size_t s = 0; s < src_ss[i].n_segments; s++) { + split_state.ne[j] += src_ss[i].ne[s*n_bufs + j]; + } + split_state.ne[j] *= tensor->ne[split_state.axis]; + if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) { + GGML_ASSERT(split_state.ne[j] % tensor->src[i]->ne[src_ss[i].axis] == 0); + split_state.ne[j] /= tensor->src[i]->ne[src_ss[i].axis]; + } + } + } else { + for (size_t j = 0; j < n_bufs; j++) { + int64_t sum = 0; + for (size_t s = 0; s < src_ss[i].n_segments; s++) { + sum += src_ss[i].ne[s*n_bufs + j]; + } + // Assert that ratio is consistent: + GGML_ASSERT(split_state.ne[j] * tensor->src[i]->ne[src_ss[i].axis] + == sum * tensor->ne[split_state.axis]); + } + } + first_src_split_by_axis = false; + } + GGML_ASSERT(!first_src_split_by_axis); + } + return split_state; + }; + + const std::pair key = std::make_pair(tensor, assume_sync); + auto it = buf_ctx->split_state_cache.find(key); + if (it != buf_ctx->split_state_cache.end() && memcmp(it->second.second, (const char *) tensor, sizeof(it->second.second)) != 0) { + buf_ctx->split_state_cache.clear(); + it = buf_ctx->split_state_cache.end(); + } + + if (it == buf_ctx->split_state_cache.end()) { + buf_ctx->split_state_cache[key].first = calculate_split_state(); + memcpy(buf_ctx->split_state_cache[key].second, tensor, sizeof(buf_ctx->split_state_cache[key].second)); + if (buf_ctx->debug > 0) { + std::string srcs_info; + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr) { + continue; + } + if (!srcs_info.empty()) { + srcs_info += ", "; + } + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true); + const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis); + std::string ne_info; + for (size_t j = 0; j < n_bufs; j++) { + if (!ne_info.empty()) { + ne_info += ", "; + } + ne_info += std::to_string(split_state.ne[j]); + } + srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]"; + } + std::string ne_info; + for (size_t j = 0; j < n_bufs; j++) { + if (!ne_info.empty()) { + ne_info += ", "; + } + ne_info += std::to_string(buf_ctx->split_state_cache[key].first.ne[j]); + } + GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op), + ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str()); + } + } + + ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key].first; + GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE); +#ifndef NDEBUG + if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { + int64_t ne_ret = 0; + for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { + ne_ret += ret.ne[sj]; + } + assert(ne_ret == tensor->ne[int(ret.axis)]); + } +#endif // NDEBUG + return ret; +} + +static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_UNUSED(buffer); + return (void *) 0x1000000000000000; // FIXME +} + +static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ true); + GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + GGML_ASSERT(split_state.n_segments <= 16); + + int split_dim = split_state.axis; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + for (size_t k = 0; k < GGML_MAX_DIMS; k++) { + ne[k] = tensor->ne[k]; + nb[k] = tensor->nb[k]; + } + + std::vector simple_tensors; + simple_tensors.reserve(n_simple_bufs); + for (size_t j = 0; j < n_simple_bufs; j++) { + ggml_context * simple_ctx = buf_ctx->buf_configs[j].ctx; + ggml_backend_buffer_t simple_buf = buf_ctx->buf_configs[j].buf; + + if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) { + // TODO: the following assert fails for llama-parallel even though the results are correct: + // GGML_ASSERT(ggml_is_contiguously_allocated(tensor)); + ne[split_dim] = 0; + for (size_t s = 0; s < split_state.n_segments; s++) { + ne[split_dim] += split_state.ne[s*n_simple_bufs + j]; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (tensor->nb[i] > tensor->nb[split_dim]) { + nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim]; + } + } + } + + ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne); + t_ij->op = tensor->op; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + t_ij->nb[i] = nb[i]; + } + t_ij->flags = tensor->flags; + memcpy(t_ij->op_params, tensor->op_params, sizeof(tensor->op_params)); + ggml_set_name(t_ij, tensor->name); + t_ij->buffer = simple_buf; + t_ij->view_src = tensor->view_src; + t_ij->view_offs = tensor->view_offs; + if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) { + t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j); + if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) { + GGML_ASSERT(ne[split_dim] != 0 && tensor->ne[split_dim] != 0); + const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis; + GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS); + + // The offset can be internal to the data split, in those cases the view offset should not be scaled. + // If however, the offset is larger than the data split then it needs to be scaled proportionally. + bool split_internal_offset = t_ij->view_offs <= tensor->view_src->nb[split_dim_view_src]; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + const size_t dim_size = tensor->ne[i] * tensor->nb[i]; + if (tensor->view_offs <= dim_size && dim_size < tensor->nb[split_dim]) { + split_internal_offset = true; + break; + } + } + if (!split_internal_offset) { + t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim]; + } + } + } + if (t_ij->view_src != nullptr) { + t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs; + } else if (simple_buf != nullptr) { + t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf) + + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(buffer)); + } + t_ij->extra = tensor->extra; + for (int i = 0; i < GGML_MAX_SRC; i++) { + t_ij->src[i] = tensor->src[i]; + if (tensor->src[i] == tensor) { + t_ij->src[i] = t_ij; + } else if (t_ij->src[i] != nullptr && ggml_backend_buffer_is_meta(t_ij->src[i]->buffer)) { + t_ij->src[i] = ggml_backend_meta_buffer_simple_tensor(tensor->src[i], j); + } + } + + simple_tensors.push_back(t_ij); + } + buf_ctx->simple_tensors[tensor] = simple_tensors; + + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + + if (split_state.n_segments != 1) { + GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + GGML_ASSERT(tensor->ne[3] == 1); + size_t offset_data = 0; + std::vector simple_offsets(n_bufs, 0); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(tensor->ne[2] == 1); + const int64_t blck_size = ggml_blck_size(tensor->type); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes, + tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*tensor->ne[1] == size); + return; + } + GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes, + tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*tensor->ne[2] == size); + return; + } + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + const size_t simple_offset = i_start * chunk_size_j; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + ggml_backend_tensor_set(simple_tensor, data, offset, size); + } + } break; + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + GGML_ASSERT(tensor->type == GGML_TYPE_F32); + const int64_t ne = ggml_nelements(tensor); + std::vector tmp; + tmp.reserve(ne); + for (int64_t i = 0; i < ne; i++) { + tmp.push_back(((const float *) data)[i] / n_bufs); + } + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + ggml_backend_tensor_set(simple_tensor, tmp.data(), offset, size); + } + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + GGML_ASSERT(split_state.n_segments == 1); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_bufs; j++){ + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + const size_t simple_offset = i_start * chunk_size_j; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + // TODO other simple backend may be better + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); + ggml_backend_tensor_get(simple_tensor, data, offset, size); + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer); + for (size_t i = 0; i < n_buffers; i++) { + ggml_backend_buffer_clear(ggml_backend_meta_buffer_simple_buffer(buffer, i), value); + } +} + +static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) { + const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer); + for (size_t i = 0; i < n_buffers; i++) { + ggml_backend_buffer_reset(ggml_backend_meta_buffer_simple_buffer(buffer, i)); + } +} + +static const ggml_backend_buffer_i ggml_backend_meta_buffer_iface = { + /* .free_buffer = */ ggml_backend_meta_buffer_free_buffer, + /* .get_base = */ ggml_backend_meta_buffer_get_base, + /* .init_tensor = */ ggml_backend_meta_buffer_init_tensor, + /* .memset_tensor = */ nullptr, // TODO implement + /* .set_tensor = */ ggml_backend_meta_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_meta_buffer_get_tensor, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_meta_buffer_clear, + /* .reset = */ ggml_backend_meta_buffer_reset, +}; + +bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) { + return buf != nullptr && buf->iface.free_buffer == ggml_backend_meta_buffer_iface.free_buffer; +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + + ggml_init_params params = { + /*.mem_size =*/ 1024*1024*1024, // FIXME + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(); + size_t max_size = 0; + buf_ctx->buf_configs.reserve(n_simple_bufts); + for (size_t i = 0; i < n_simple_bufts; i++) { + ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size); + max_size = std::max(max_size, ggml_backend_buffer_get_size(simple_buf)); + buf_ctx->buf_configs.emplace_back(ggml_init(params), simple_buf); + } + + return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, buf_ctx, max_size); +} + +struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + + ggml_init_params params = { + /*.mem_size =*/ 1024*1024*1024, // FIXME + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(); + meta_buf_ctx->buf_configs.reserve(n_simple_bufts); + for (size_t i = 0; i < n_simple_bufts; i++) { + meta_buf_ctx->buf_configs.emplace_back(ggml_init(params), nullptr); + } + + ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = meta_buf; + ggml_backend_meta_buffer_init_tensor(meta_buf, t); + t->data = (void *) 0x2000000000000000; // FIXME + } + for (size_t i = 0; i < n_simple_bufts; i++) { + meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft( + meta_buf_ctx->buf_configs[i].ctx, ggml_backend_meta_buft_simple_buft(buft, i)); + meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->buf_configs[i].buf)); + } + return meta_buf; +} + +// +// meta backend +// + +static ggml_guid_t ggml_backend_meta_guid() { + static ggml_guid guid = {0xf1, 0x0e, 0x34, 0xcf, 0x9c, 0x6f, 0x43, 0xcb, 0x96, 0x92, 0xbe, 0x8e, 0xbb, 0x71, 0x3f, 0xda}; + return &guid; +} + +struct ggml_backend_meta_context { + struct cgraph_config { + ggml_cgraph * cgraph_main = nullptr; + int offset = 0; // Node offset vs. original graph + + std::vector cgraphs_aux; + }; + struct backend_config { + ggml_backend_t backend; + + std::vector cgraphs; + std::vector nodes; + ggml_backend_buffer_ptr buf; + + backend_config(ggml_backend_t backend) : backend(backend) {} + }; + std::string name; + std::vector backend_configs; + ggml_context_ptr ctx; + std::vector cgraphs_aux; + std::vector nodes_aux; + int max_nnodes = 0; + size_t max_tmp_size = 0; + size_t max_subgraphs = 0; + + ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) { + const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev); + name = "Meta("; + backend_configs.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i); + if (i > 0) { + name += ","; + } + name += ggml_backend_dev_name(simple_dev); + backend_configs.emplace_back(ggml_backend_dev_init(simple_dev, params)); + } + name += ")"; + } + + ~ggml_backend_meta_context() { + for (auto & bc : backend_configs) { + ggml_backend_free(bc.backend); + } + } + + size_t n_reduce_steps() const { + return std::ceil(std::log2(backend_configs.size())); + } +}; + +static const char * ggml_backend_meta_get_name(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_meta(backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) backend->context; + return backend_ctx->name.c_str(); +} + +static void ggml_backend_meta_free(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_meta(backend)); + ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + delete backend_ctx; + delete backend; +} + +static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + GGML_ASSERT(offset == 0); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + GGML_ASSERT(split_state.n_segments == 1); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_backends; j++){ + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, + i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + for (size_t j = 0; j < n_backends; j++) { + ggml_backend_tensor_set_async( + ggml_backend_meta_simple_backend(backend, j), ggml_backend_meta_buffer_simple_tensor(tensor, j), data, offset, size); + } + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + GGML_ASSERT(offset == 0); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + GGML_ASSERT(split_state.n_segments == 1); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_backends; j++){ + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j, + i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + // TODO other simple backend may be better + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0); + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); + ggml_backend_tensor_get_async(simple_backend, simple_tensor, data, offset, size); + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_synchronize(ggml_backend_t backend) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + for (size_t i = 0; i < n_backends; i++) { + ggml_backend_synchronize(ggml_backend_meta_simple_backend(backend, i)); + } +} + +static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(cgraph->grads == nullptr); + const size_t n_backends = ggml_backend_meta_n_backends(backend); + ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + + bool max_nnodes_raised = false; + if (cgraph->n_nodes > backend_ctx->max_nnodes) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.nodes.resize(cgraph->n_nodes); + bcj.cgraphs.resize(cgraph->n_nodes); + } + backend_ctx->max_nnodes = cgraph->n_nodes; + max_nnodes_raised = true; + } + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. + // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. + bcj.nodes[i] = node; + continue; + } + bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); + GGML_ASSERT(bcj.nodes[i]); + } + } + + size_t n_subgraphs = 0; + size_t max_tmp_size = 0; + { + // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: + auto get_i_delayed = [&](const int i) -> int { + int id = i; // i_delayed + int idr = i; // i_delayed return, last safe return value + + ggml_tensor * node = cgraph->nodes[id]; + int32_t n_used = ggml_node_get_use_count(cgraph, id); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_ADD_ID && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && + ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_MUL && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + + if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { + return idr; + } + for (int32_t k = 0; k < n_used; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || + next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || + ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + for (int32_t k = 0; k < n_used - 2; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + idr = id; + return idr; + }; + + int i_start = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + continue; + } + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); + } + const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; + if (!new_subgraph) { + continue; + } + + i = get_i_delayed(i); + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.cgraphs[n_subgraphs].offset = i_start; + } + n_subgraphs++; + i_start = i + 1; + } + GGML_ASSERT(i_start == cgraph->n_nodes); + } + + if (max_tmp_size > backend_ctx->max_tmp_size) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.buf.reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); + } + backend_ctx->max_tmp_size = max_tmp_size; + } + + + if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { + backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); + const size_t n_reduce_steps = backend_ctx->n_reduce_steps(); + const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step + const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step + const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); + const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); + const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); + ggml_init_params params = { + /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + backend_ctx->ctx.reset(ggml_init(params)); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i = 0; i < n_subgraphs; i++) { + bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); + } + } + backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { + backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); + } + backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { + backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1); + } + } + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { + ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; + const size_t i_node_start = bcj.cgraphs[i_graph].offset; + const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; + cgraph_ij->n_nodes = i_node_stop - i_node_start; + ggml_hash_set_reset(&cgraph_ij->visited_hash_set); + for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { + ggml_tensor * node_ij = bcj.nodes[i_node]; + cgraph_ij->nodes[i_node - i_node_start] = node_ij; + const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); + const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); + cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; + } + } + } + + size_t iga = 0; // i graph aux + size_t ina = 0; // i node aux + + // FIXME usage_counts + auto get_cgraph_aux = [&]() -> ggml_cgraph * { + ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++]; + return ret; + }; + auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * { + ggml_tensor * ret = backend_ctx->nodes_aux[ina++]; + memset(ret, 0, sizeof(ggml_tensor)); + ret->op = GGML_OP_NONE; + ret->type = t->type; + for (size_t k = 0; k < GGML_MAX_DIMS; k++) { + ret->ne[k] = t->ne[k]; + ret->nb[k] = t->nb[k]; + } + return ret; + }; + + // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable: + auto allreduce_fallback = [&](size_t i) -> ggml_status { + std::vector step_cgraphs(n_backends, nullptr); + + for (size_t offset_j = 1; offset_j < n_backends; offset_j *= 2) { + std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); + + for (size_t j = 0; j < n_backends; j++) { + const size_t j_other = j ^ offset_j; + if (j_other > j) { + continue; + } + + auto & bcj1 = backend_ctx->backend_configs[j]; + auto & bcj2 = backend_ctx->backend_configs[j_other]; + + ggml_tensor * node1 = bcj1.cgraphs[i].cgraph_main->nodes[bcj1.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_tensor * node2 = bcj2.cgraphs[i].cgraph_main->nodes[bcj2.cgraphs[i].cgraph_main->n_nodes - 1]; + GGML_ASSERT(ggml_is_contiguous(node1)); + GGML_ASSERT(ggml_is_contiguous(node2)); + + // Tmp tensors to receive P2P copies + ggml_tensor * node_tmp_1 = get_node_aux(node1); + node_tmp_1->buffer = bcj1.buf.get(); + node_tmp_1->data = ggml_backend_buffer_get_base(bcj1.buf.get()); + + ggml_tensor * node_tmp_2 = get_node_aux(node2); + node_tmp_2->buffer = bcj2.buf.get(); + node_tmp_2->data = ggml_backend_buffer_get_base(bcj2.buf.get()); + + // 2 P2P copies: exchange full buffers + ggml_backend_tensor_copy_async(bcj1.backend, bcj2.backend, node1, node_tmp_2); + ggml_backend_tensor_copy_async(bcj2.backend, bcj1.backend, node2, node_tmp_1); + + // Local ADD: node1 += tmp1 (in-place via view) + ggml_tensor * node_red_1 = get_node_aux(node1); + node_red_1->view_src = node1->view_src == nullptr ? node1 : node1->view_src; + node_red_1->view_offs = node1->view_offs; + node_red_1->op = GGML_OP_ADD; + node_red_1->src[0] = node1; + node_red_1->src[1] = node_tmp_1; + node_red_1->flags |= GGML_TENSOR_FLAG_COMPUTE; + ggml_backend_view_init(node_red_1); + + // Local ADD: node2 += tmp2 (in-place via view) + ggml_tensor * node_red_2 = get_node_aux(node2); + node_red_2->view_src = node2->view_src == nullptr ? node2 : node2->view_src; + node_red_2->view_offs = node2->view_offs; + node_red_2->op = GGML_OP_ADD; + node_red_2->src[0] = node2; + node_red_2->src[1] = node_tmp_2; + node_red_2->flags |= GGML_TENSOR_FLAG_COMPUTE; + ggml_backend_view_init(node_red_2); + + // Build 1-node cgraphs for the ADD ops + ggml_cgraph * cgraph_aux_1 = get_cgraph_aux(); + cgraph_aux_1->nodes[0] = node_red_1; + cgraph_aux_1->n_nodes = 1; + step_cgraphs[j] = cgraph_aux_1; + + ggml_cgraph * cgraph_aux_2 = get_cgraph_aux(); + cgraph_aux_2->nodes[0] = node_red_2; + cgraph_aux_2->n_nodes = 1; + step_cgraphs[j_other] = cgraph_aux_2; + } + + // Execute local ADDs for this step + for (size_t j = 0; j < n_backends; j++) { + if (step_cgraphs[j] == nullptr) { + continue; + } + auto & bcj = backend_ctx->backend_configs[j]; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + } + return GGML_STATUS_SUCCESS; + }; + + + for (size_t i = 0; i < n_subgraphs; i++) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + + if (n_backends > 1 && i < n_subgraphs - 1) { + bool backend_allreduce_success = false; + ggml_backend_allreduce_tensor_t allreduce_tensor = (ggml_backend_allreduce_tensor_t) ggml_backend_reg_get_proc_address( + ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_ctx->backend_configs[0].backend)), "ggml_backend_allreduce_tensor"); + if (allreduce_tensor) { + std::vector backends; + backends.reserve(n_backends); + std::vector nodes; + nodes.reserve(n_backends); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + backends.push_back(bcj.backend); + ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main; + nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]); + } + backend_allreduce_success = allreduce_tensor(backends.data(), nodes.data(), n_backends); + } + + if (!backend_allreduce_success) { + const ggml_status status = allreduce_fallback(i); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + } + } + return GGML_STATUS_SUCCESS; +} + +static const ggml_backend_i ggml_backend_meta_i = { + /* .get_name = */ ggml_backend_meta_get_name, + /* .free = */ ggml_backend_meta_free, + /* .set_tensor_async = */ ggml_backend_meta_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_meta_get_tensor_async, + /* .get_tensor_2d_async = */ nullptr, + /* .set_tensor_2d_async = */ nullptr, + /* .cpy_tensor_async = */ nullptr, + /* .synchronize = */ ggml_backend_meta_synchronize, + /* .graph_plan_create = */ nullptr, + /* .graph_plan_free = */ nullptr, + /* .graph_plan_update = */ nullptr, + /* .graph_plan_compute = */ nullptr, + /* .graph_compute = */ ggml_backend_meta_graph_compute, + /* .event_record = */ nullptr, + /* .event_wait = */ nullptr, + /* .graph_optimize = */ nullptr, +}; + +bool ggml_backend_is_meta(ggml_backend_t backend) { + return backend != nullptr && backend->iface.get_name == ggml_backend_meta_i.get_name; +} + +static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params) { + ggml_backend_meta_context * backend_ctx = new ggml_backend_meta_context(dev, params); + + ggml_backend_t backend = new struct ggml_backend; + backend->guid = ggml_backend_meta_guid(); + backend->iface = ggml_backend_meta_i; + backend->device = dev; + backend->context = backend_ctx; + return backend; +} + +size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend) { + GGML_ASSERT(ggml_backend_is_meta(meta_backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; + return backend_ctx->backend_configs.size(); +} + +ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index) { + GGML_ASSERT(ggml_backend_is_meta(meta_backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; + return backend_ctx->backend_configs[index].backend; +} + diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 22c656996cc..1a555bf2a4d 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -123,7 +123,7 @@ size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { GGML_ASSERT(buffer); // get_base is optional if the buffer is zero-sized - if (buffer->size == 0) { + if (!ggml_backend_buffer_is_meta(buffer) && buffer->size == 0) { return NULL; } @@ -279,15 +279,57 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten } } +void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_set_async(backend, tensor, (const char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + backend->iface.set_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + +void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_get_async(backend, tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + backend->iface.get_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); if (size == 0) { return; } - GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); @@ -297,18 +339,62 @@ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, siz void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); if (size == 0) { return; } - GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); buf->iface.get_tensor(buf, tensor, data, offset, size); } +void ggml_backend_tensor_set_2d(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(tensor); + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + + if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_set(tensor, (const char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + + buf->iface.set_tensor_2d(buf, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + +void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(tensor); + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + + if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_get(tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + + buf->iface.get_tensor_2d(buf, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; @@ -388,7 +474,7 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) { // backend copy -void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { +void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); if (src == dst) { @@ -402,7 +488,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } else if (!ggml_backend_buffer_copy_tensor(src, dst)) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer)); -#endif +#endif // NDEBUG size_t nbytes = ggml_nbytes(src); void * data = malloc(nbytes); ggml_backend_tensor_get(src, data, 0, nbytes); @@ -411,7 +497,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } } -void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) { +void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); if (src == dst) { @@ -500,6 +586,7 @@ enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) { } void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) { + GGML_ASSERT(device); memset(props, 0, sizeof(*props)); device->iface.get_props(device, props); } @@ -610,6 +697,8 @@ static const struct ggml_backend_buffer_i ggml_backend_multi_buffer_i = { /* .memset_tensor = */ NULL, /* .set_tensor = */ NULL, /* .get_tensor = */ NULL, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_multi_buffer_clear, /* .reset = */ NULL, @@ -1899,8 +1988,9 @@ enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct GGML_ASSERT(tensor->data == NULL); GGML_ASSERT(tensor->view_src == NULL); GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer)); - GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= - (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer) || + (char *) addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= + (char *) ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); tensor->buffer = buffer; tensor->data = addr; @@ -2174,6 +2264,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, /* .clear = */ ggml_backend_cpu_buffer_clear, /* .reset = */ NULL, @@ -2186,6 +2278,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = { /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, /* .clear = */ ggml_backend_cpu_buffer_clear, /* .reset = */ NULL, diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index e7a1763b54d..05245b69807 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -262,6 +262,8 @@ static struct ggml_backend_i blas_backend_i = { /* .get_name = */ ggml_backend_blas_get_name, /* .free = */ ggml_backend_blas_free, /* .set_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .get_tensor_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 40fe3d82ecc..5fc484b342b 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1457,6 +1457,8 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor, /* .clear = */ ggml_backend_cann_buffer_clear, /* .reset = */ NULL, @@ -2698,6 +2700,8 @@ static const ggml_backend_i ggml_backend_cann_interface = { /* .free = */ ggml_backend_cann_free, /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async, /* .synchronize = */ ggml_backend_cann_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 9baf3e025e6..1118f7169c9 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -111,6 +111,8 @@ static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = { /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor, /* .get_tensor = */ nullptr, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, /* .cpy_tensor = */ nullptr, /* .clear = */ ggml_backend_amx_buffer_clear, /* .reset = */ nullptr, diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index ddf1737a317..49f840be207 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -195,6 +195,8 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .free = */ ggml_backend_cpu_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 419862101d1..b54d4a6b107 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -181,6 +181,16 @@ if (CUDAToolkit_FOUND) target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver) endif() + if (GGML_CUDA_NCCL) + find_package(NCCL) + if (NCCL_FOUND) + add_compile_definitions(GGML_USE_NCCL) + target_link_libraries(ggml-cuda PRIVATE NCCL::NCCL) + else() + message(STATUS "Warning: NCCL not found, performance for multiple CUDA GPUs will be suboptimal") + endif() + endif() + set(CUDA_CXX_FLAGS "") set(CUDA_FLAGS -use_fast_math -extended-lambda) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 65d7a6e22ae..64b91811c39 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -186,6 +186,10 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str) +#ifdef GGML_USE_NCCL +#define NCCL_CHECK(err) CUDA_CHECK_GEN(err, ncclSuccess, ncclGetErrorString) +#endif // GGML_USE_NCCL + #if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) static const char * cu_get_error_str(CUresult err) { const char * err_str; @@ -1086,6 +1090,10 @@ struct ggml_cuda_device_info { cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {}; std::array default_tensor_split = {}; + +#ifdef GGML_USE_NCCL + ncclComm_t comms[GGML_CUDA_MAX_DEVICES]; +#endif // GGML_USE_NCCL }; const ggml_cuda_device_info & ggml_cuda_info(); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 648124c0d31..841af0726b6 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -324,6 +324,28 @@ static ggml_cuda_device_info ggml_cuda_init() { // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); + for (int id = 0; id < info.device_count; ++id) { + ggml_cuda_set_device(id); + for (int id_other = 0; id_other < info.device_count; ++id_other) { + if (id == id_other) { + continue; + } + int can_access_peer; + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); + if (can_access_peer) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); + } + } + } + +#ifdef GGML_USE_NCCL + int dev_ids[GGML_CUDA_MAX_DEVICES]; + for (int id = 0; id < info.device_count; ++id) { + dev_ids[id] = id; + } + NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids)); +#endif // GGML_USE_NCCL + return info; } @@ -632,26 +654,46 @@ static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer } static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { - ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread)); + CUDA_CHECK(cudaMemsetAsync((char *) tensor->data + offset, value, size, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); + CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemcpy2DAsync( + (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CUDA_CHECK(cudaMemcpy2DAsync( + data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } @@ -691,6 +733,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = { /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor, + /* .set_tensor_2d = */ ggml_backend_cuda_buffer_set_tensor_2d, + /* .get_tensor_2d = */ ggml_backend_cuda_buffer_get_tensor_2d, /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor, /* .clear = */ ggml_backend_cuda_buffer_clear, /* .reset = */ NULL, @@ -1003,6 +1047,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_cuda_split_buffer_clear, /* .reset = */ NULL, @@ -1079,6 +1125,83 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, }; +bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends) { +#ifdef GGML_USE_NCCL + const int64_t ne = ggml_nelements(tensors[0]); + // FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0 + // This then causes a crash in this function + if (ne == 0) { + return true; + } + for (size_t i = 0; i < n_backends; ++i) { + GGML_ASSERT(tensors[i] != nullptr); + GGML_ASSERT(ggml_nelements(tensors[i]) == ne); + GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i])); + } + + const ggml_cuda_device_info info = ggml_cuda_info(); + + // For small tensors, simply reduce them as FP32. + // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0. + if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) { + NCCL_CHECK(ncclGroupStart()); + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream())); + } + NCCL_CHECK(ncclGroupEnd()); + + return true; + } + + // For large tensors it's faster to compress them to BF16 for the reduction: + to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32); + to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); + + ggml_cuda_pool_alloc tmp[GGML_CUDA_MAX_DEVICES]; + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + tmp[i].pool = &cuda_ctx->pool(); + tmp[i].alloc(ne); + + ggml_cuda_set_device(i); + to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream()); + CUDA_CHECK(cudaGetLastError()); + } + + NCCL_CHECK(ncclGroupStart()); + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream())); + } + NCCL_CHECK(ncclGroupEnd()); + + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + + ggml_cuda_set_device(i); + to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream()); + CUDA_CHECK(cudaGetLastError()); + } + + return true; +#else + // If NCCL is installed it is used by default for optimal performance. + // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package. + // RCCL is disabled by default, users are explicitly opting in. + // Therefore print no warning for RCCL. +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + static bool warning_printed = false; + if (!warning_printed) { + GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__); + warning_printed = true; + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + GGML_UNUSED_VARS(backends, tensors, n_backends); + return false; +#endif // GGML_USE_NCCL +} + ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) { static std::mutex mutex; std::lock_guard lock(mutex); @@ -1425,64 +1548,6 @@ static void ggml_cuda_op_mul_mat_cublas( GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size); } -static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { - static bool peer_access_enabled = false; - - const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE; - - if (peer_access_enabled == enable_peer_access) { - return; - } - -#ifdef NDEBUG - for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - ggml_cuda_set_device(id); - CUDA_CHECK(cudaDeviceSynchronize()); - } - - for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - ggml_cuda_set_device(id); - - for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) { - if (id == id_other) { - continue; - } - if (id != main_device && id_other != main_device) { - continue; - } - - int can_access_peer; - CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); - if (can_access_peer) { - if (enable_peer_access) { - cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0); - if (err != cudaErrorPeerAccessAlreadyEnabled) { - CUDA_CHECK(err); - } else { - // reset the error - (void)cudaGetLastError(); - } - } else { - cudaError_t err = cudaDeviceDisablePeerAccess(id_other); - if (err != cudaErrorPeerAccessNotEnabled) { - CUDA_CHECK(err); - } else { - // reset the error - (void)cudaGetLastError(); - } - } - } - } - } - - ggml_cuda_set_device(main_device); -#endif // NDEBUG - - peer_access_enabled = enable_peer_access; - - GGML_UNUSED(main_device); -} - static cudaError_t ggml_cuda_Memcpy2DPeerAsync( void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { @@ -2483,11 +2548,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { - // why is this here instead of mul_mat? - if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) { - ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); - } - switch (dst->op) { case GGML_OP_ARGMAX: ggml_cuda_argmax(ctx, dst); @@ -2845,21 +2905,43 @@ static void ggml_backend_cuda_free(ggml_backend_t backend) { } static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream())); + CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream())); } static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream())); + CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_set_tensor_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + + CUDA_CHECK(cudaMemcpy2DAsync( + (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_get_tensor_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + + CUDA_CHECK(cudaMemcpy2DAsync( + data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cuda_ctx->stream())); } static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { @@ -2870,21 +2952,21 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ return false; } - if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) { return false; } // device -> device copy - ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context; - ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context; + ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *) backend_src->context; + ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *) backend_dst->context; - ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; - ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; + ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context; + ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context; if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); -#endif +#endif // NDEBUG return false; } @@ -2897,7 +2979,7 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ return false; #else CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream())); -#endif +#endif // GGML_CUDA_NO_PEER_COPY } // record event on src stream after the copy @@ -4343,6 +4425,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .free = */ ggml_backend_cuda_free, /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, + /* .get_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async, + /* .set_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async, /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async, /* .synchronize = */ ggml_backend_cuda_synchronize, /* .graph_plan_create = */ NULL, @@ -5130,6 +5214,9 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { GGML_UNUSED(reg); + if (strcmp(name, "ggml_backend_allreduce_tensor") == 0) { + return (void *)ggml_backend_cuda_allreduce_tensor; + } if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { return (void *)ggml_backend_cuda_split_buffer_type; } diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index 07bc47df3b8..323c9801934 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -6,6 +6,10 @@ #include #include +#ifdef GGML_USE_NCCL +#include +#endif // GGML_USE_NCCL + #if CUDART_VERSION >= 11080 #include #define FP8_AVAILABLE diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 9d9ba1ee219..d146e018d94 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -10,6 +10,11 @@ #include #endif // defined(GGML_HIP_ROCWMMA_FATTN) +#ifdef GGML_USE_NCCL +#include +#endif // GGML_USE_NCCL + + #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N @@ -28,6 +33,7 @@ #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} +#define NCCL_CHECK(fn) {ncclResult_t err = fn; if(err != ncclSuccess) { GGML_ABORT("RCCL Failure RCCL returned: %i\n", err); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index f91bc46552e..ac5baa2acaf 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1491,6 +1491,8 @@ static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_hexagon_buffer_set_tensor, /* .get_tensor = */ ggml_backend_hexagon_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_hexagon_buffer_cpy_tensor, /* .clear = */ ggml_backend_hexagon_buffer_clear, /* .reset = */ NULL, @@ -3002,6 +3004,8 @@ static struct ggml_backend_i hexagon_backend_i = { /* .free = */ ggml_backend_hexagon_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_hexagon_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 291b4837455..a7d4e0ea2b5 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -47,6 +47,10 @@ find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) +if (GGML_HIP_RCCL) + find_package(rccl REQUIRED) +endif() + if (${hip_VERSION} VERSION_LESS 6.1) message(FATAL_ERROR "At least ROCM/HIP V6.1 is required") endif() @@ -118,6 +122,10 @@ if (NOT GGML_HIP_MMQ_MFMA) add_compile_definitions(GGML_HIP_NO_MMQ_MFMA) endif() +if (GGML_HIP_RCCL) + add_compile_definitions(GGML_USE_NCCL) # RCCL has the same interface as NCCL. +endif() + if (GGML_HIP_EXPORT_METRICS) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps") endif() @@ -142,4 +150,8 @@ if (GGML_STATIC) message(FATAL_ERROR "Static linking not supported for HIP/ROCm") endif() +if (GGML_HIP_RCCL) + target_link_libraries(ggml-hip PRIVATE ggml-base roc::rccl) +endif() + target_link_libraries(ggml-hip PRIVATE ggml-base hip::host roc::rocblas roc::hipblas) diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 9382ce53b36..4dbf8e6fea9 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -90,6 +90,8 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = { /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, /* .clear = */ ggml_backend_metal_buffer_shared_clear, /* .reset = */ NULL, @@ -158,15 +160,17 @@ static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer } static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_private_get_base, - /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, - /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_private_clear, - /* .reset = */ NULL, + /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_private_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_private_clear, + /* .reset = */ NULL, }; static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) { @@ -563,6 +567,8 @@ static ggml_backend_i ggml_backend_metal_i = { /* .free = */ ggml_backend_metal_free, /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups /* .synchronize = */ ggml_backend_metal_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 6f3fc5886d8..f1a28a7f4cd 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4063,6 +4063,8 @@ static ggml_backend_i ggml_backend_opencl_i = { /* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */ /* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */ /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .synchronize = */ ggml_backend_opencl_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, @@ -5778,6 +5780,8 @@ static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor, /* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_opencl_buffer_clear, /* .reset = */ ggml_backend_opencl_buffer_reset, diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index b3058b4af73..0c8d3508e87 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -412,6 +412,8 @@ static const ggml_backend_buffer_i ggml_backend_openvino_buffer_interface = { /* .memset_tensor = */ ggml_backend_openvino_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_openvino_buffer_set_tensor, /* .get_tensor = */ ggml_backend_openvino_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_openvino_buffer_cpy_tensor, /* .clear = */ ggml_backend_openvino_buffer_clear, /* .reset = */ NULL, @@ -617,6 +619,8 @@ static const ggml_backend_i ggml_backend_openvino_interface = { /* .free = */ ggml_backend_openvino_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 4e2f1ab0f23..61bfcc5a675 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -706,6 +706,8 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor, /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor, /* .clear = */ ggml_backend_rpc_buffer_clear, /* .reset = */ NULL, @@ -894,6 +896,8 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, /* .cpy_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .synchronize = */ ggml_backend_rpc_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 7f9b2df524e..989c91a6abb 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -638,6 +638,8 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, /* .clear = */ ggml_backend_sycl_buffer_clear, /* .reset = */ ggml_backend_sycl_buffer_reset, @@ -1084,6 +1086,8 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_sycl_split_buffer_clear, /* .reset = */ NULL, @@ -4553,6 +4557,8 @@ static ggml_backend_i ggml_backend_sycl_interface = { /* .free = */ ggml_backend_sycl_free, /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async, /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_sycl_cpy_tensor_async, // // TODO: update for the new // interface diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp index 6b95362dd80..b6c561cd61e 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp @@ -101,6 +101,8 @@ const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor, /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, /* .clear = */ ggml_backend_remoting_buffer_clear, /* .reset = */ NULL, @@ -113,6 +115,8 @@ const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor_from_ptr, /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor_from_ptr, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, /* .clear = */ ggml_backend_remoting_buffer_clear, /* .reset = */ NULL, diff --git a/ggml/src/ggml-virtgpu/ggml-backend.cpp b/ggml/src/ggml-virtgpu/ggml-backend.cpp index a63ee2b9d2f..2b978556228 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend.cpp @@ -34,6 +34,8 @@ static ggml_backend_i ggml_backend_remoting_interface = { /* .free = */ ggml_backend_remoting_free, /* .set_tensor_async = */ NULL, // ggml_backend_remoting_set_tensor_async, /* .get_tensor_async = */ NULL, // ggml_backend_remoting_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_remoting_cpy_tensor_async, /* .synchronize = */ NULL, // ggml_backend_remoting_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 19e7fbdaae7..20a4d30d5eb 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -13521,6 +13521,8 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, /* .clear = */ ggml_backend_vk_buffer_clear, /* .reset = */ NULL, @@ -14979,6 +14981,8 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .free = */ ggml_backend_vk_free, /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async, /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async, /* .synchronize = */ ggml_backend_vk_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b8df0f4dd05..edfc6579171 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3013,6 +3013,8 @@ static ggml_backend_i ggml_backend_webgpu_i = { /* .free = */ ggml_backend_webgpu_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, @@ -3170,6 +3172,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, // TODO: optional, implement this /* .clear = */ ggml_backend_webgpu_buffer_clear, /* .reset = */ NULL, // TODO: optional, think it coordinates with diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index 9b6938abf7e..e6b6fc24fd7 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -313,6 +313,8 @@ static ggml_backend_buffer_i ggml_backend_zdnn_buffer_i = { /* .memset_tensor = */ ggml_backend_zdnn_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_zdnn_buffer_set_tensor, /* .get_tensor = */ ggml_backend_zdnn_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_zdnn_buffer_clear, /* .reset = */ NULL, @@ -417,20 +419,22 @@ static enum ggml_status ggml_backend_zdnn_graph_compute(ggml_backend_t backend, } static ggml_backend_i ggml_backend_zdnn_i = { - /* .get_name = */ ggml_backend_zdnn_name, - /* .free = */ ggml_backend_zdnn_free, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_zdnn_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, - /* .graph_optimize = */ NULL, + /* .get_name = */ ggml_backend_zdnn_name, + /* .free = */ ggml_backend_zdnn_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_zdnn_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; static ggml_guid_t ggml_backend_zdnn_guid(void) { diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 377303720c7..fc1df4dbef4 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -407,6 +407,8 @@ static struct ggml_backend_i ggml_backend_zendnn_i = { /* .free = */ ggml_backend_zendnn_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, From c77a33df06f64eda3cff5dd54a99e7b3fdbb152c Mon Sep 17 00:00:00 2001 From: andyluo7 <43718156+andyluo7@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:13:32 +0300 Subject: [PATCH 111/249] HIP: add CDNA4 (gfx950) architecture support for MI350X/MI355X (llama/21570) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add AMD Instinct MI350X/MI355X (gfx950, CDNA4) support: - vendors/hip.h: Add CDNA4 preprocessor define for __gfx950__ - common.cuh: Add GGML_CUDA_CC_CDNA4 and GGML_CUDA_CC_IS_CDNA4 macros - mma.cuh: Route CDNA4 to compatible MFMA instructions: * f32 matmul: mfma_f32_16x16x4f32 (xf32 variant unavailable on gfx950) * bf16 matmul: mfma_f32_16x16x16bf16_1k (same as CDNA3) * int8 matmul: mfma_i32_16x16x32_i8/32x32x16 (same as CDNA3) - mmq.cuh: Include CDNA4 in stream-k kernel dispatch CDNA4 is largely compatible with CDNA3 except: - No xf32 MFMA (mfma_f32_16x16x8_xf32) — routes to f32 path - Different FP8 format (e4m3fn vs e4m3_fnuz) — not changed here Tested on AMD Instinct MI355X (gfx950), ROCm 7.0.1: - Build: compiles cleanly with -DAMDGPU_TARGETS=gfx950 - llama-bench (Qwen2.5-1.5B Q4_K_M, single GPU): * f16+FA: 40,013 tok/s prefill, 254 tok/s decode * q8_0+FA: functional - Flash attention: works correctly - MMQ: works correctly with stream-k dispatch Co-authored-by: Andy Luo --- ggml/src/ggml-cuda/common.cuh | 4 +++- ggml/src/ggml-cuda/mma.cuh | 17 +++++++++-------- ggml/src/ggml-cuda/mmq.cuh | 2 +- ggml/src/ggml-cuda/vendors/hip.h | 8 ++++++-- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 64b91811c39..56a67f1edc8 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -67,6 +67,7 @@ #define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x90a) // MI210 (gfx90a), minimum acc register renaming #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 +#define GGML_CUDA_CC_CDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x950) // MI350X/MI355X // RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 @@ -87,7 +88,8 @@ #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) #define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2) #define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3) -#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_CDNA4) +#define GGML_CUDA_CC_IS_CDNA4(cc) (cc >= GGML_CUDA_CC_CDNA4 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads #define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 5d1dadd3e4f..c91dd2d9ad6 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -1025,7 +1025,8 @@ namespace ggml_cuda_mma { const floatx2_t& a_frag = reinterpret_cast(A.x[0]); const floatx2_t& b_frag = reinterpret_cast(B.x[0]); acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA1) +#elif defined(CDNA4) || defined(CDNA2) || defined(CDNA1) + // CDNA4 (gfx950) does not support xf32 MFMA, use f32 path like CDNA2/CDNA1 #pragma unroll for (int i = 0; i < 2; ++i) { acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0); @@ -1187,7 +1188,7 @@ namespace ggml_cuda_mma { #elif defined(AMD_MFMA_AVAILABLE) using floatx4_t = __attribute__((ext_vector_type(4))) float; floatx4_t& acc_frag = reinterpret_cast(D.x[0]); -#if defined(CDNA3) || defined(CDNA2) +#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2) using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16; const bf16x4_t& a_frag = reinterpret_cast(A.x[0]); const bf16x4_t& b_frag = reinterpret_cast(B.x[0]); @@ -1216,12 +1217,12 @@ namespace ggml_cuda_mma { #if defined(AMD_MFMA_AVAILABLE) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * acc = (int32x4_t *) D.x; -#if defined(CDNA3) +#if defined(CDNA4) || defined(CDNA3) acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA) +#elif defined(CDNA2) || defined(CDNA1) acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], @@ -1230,7 +1231,7 @@ namespace ggml_cuda_mma { B.x[1], acc[0], 0, 0, 0); -#endif // defined(CDNA3) +#endif // defined(CDNA4) || defined(CDNA3) #elif defined(AMD_WMMA_AVAILABLE) @@ -1295,12 +1296,12 @@ namespace ggml_cuda_mma { #if defined(AMD_MFMA_AVAILABLE) using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; int32x16_t * acc = (int32x16_t *) D.x; -#if defined(CDNA3) +#if defined(CDNA4) || defined(CDNA3) acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA) +#elif defined(CDNA2) || defined(CDNA1) acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], @@ -1309,7 +1310,7 @@ namespace ggml_cuda_mma { B.x[1], acc[0], 0, 0, 0); -#endif // defined(CDNA3) +#endif // defined(CDNA4) || defined(CDNA3) #else GGML_UNUSED_VARS(D, A, B); diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 489d3616bb4..18911141472 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3645,7 +3645,7 @@ static __global__ void mul_mat_q( tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); return; } -#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA +#endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA constexpr int ITER_K = get_iter_k(type); diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index d146e018d94..898fec31e36 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -189,6 +189,10 @@ #define GCN #endif // defined(GCN5) || defined(GCN4) +#if defined(__gfx950__) +#define CDNA4 +#endif // defined(__gfx950__) + #if defined(__gfx942__) #define CDNA3 #endif // defined(__gfx942__) @@ -201,9 +205,9 @@ #define CDNA1 #endif // defined(__gfx908__) -#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #define CDNA // For the entire family -#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#endif // defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #if defined(__GFX12__) #define RDNA4 From 28347201fcd8771fdec88fbcad39eff597ee7866 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 10 Apr 2026 10:24:09 +0800 Subject: [PATCH 112/249] CUDA: fuse muls (llama/21665) --- ggml/src/ggml-cuda/binbcast.cu | 30 ++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/binbcast.cuh | 1 + ggml/src/ggml-cuda/ggml-cuda.cu | 18 +++++++++++------- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 7339fe0c070..adb4d5f0cb9 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -472,6 +472,36 @@ void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, } } +void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) { + GGML_ASSERT(2 <= n_fuse && n_fuse <= 8); + + switch (n_fuse) { + case 2: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 3: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 4: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 5: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 6: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 7: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 8: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + default: + GGML_ASSERT(false && "Unsupported n_fuse value"); + } +} + void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh index 62bc950111b..12624785b44 100644 --- a/ggml/src/ggml-cuda/binbcast.cuh +++ b/ggml/src/ggml-cuda/binbcast.cuh @@ -9,3 +9,4 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); +void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 841af0726b6..8613d20b9f9 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3758,10 +3758,10 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } - if (node->op == GGML_OP_ADD) { + if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) { int n_fuse = 0; ggml_op ops[8]; - std::fill(ops, ops + 8, GGML_OP_ADD); + std::fill(ops, ops + 8, node->op); for (; n_fuse <= 6; ++n_fuse){ if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { @@ -3778,13 +3778,17 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud n_fuse++; if (n_fuse > 1) { - ggml_tensor fused_add_node; - memcpy(&fused_add_node, node, sizeof(ggml_tensor)); + ggml_tensor fused_node; + memcpy(&fused_node, node, sizeof(ggml_tensor)); for (int j = 0; j < n_fuse - 1; ++j) { - fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + } + fused_node.data = cgraph->nodes[i + n_fuse - 1]->data; + if (node->op == GGML_OP_ADD) { + ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse); + } else { + ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse); } - fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data; - ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse); i += n_fuse - 1; continue; From 458ad1d93ec9c5c08752cc409cebf09a06ddd8ea Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 10 Apr 2026 01:35:27 -0500 Subject: [PATCH 113/249] vulkan: Support Q1_0 (llama/21539) * vulkan: Support Q1_0 * use get_dm --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 33 +++++++++++++++++++ .../vulkan-shaders/copy_to_quant.comp | 25 ++++++++++++++ .../vulkan-shaders/dequant_funcs.glsl | 24 ++++++++++++++ .../vulkan-shaders/dequant_funcs_cm2.glsl | 16 ++++++++- .../vulkan-shaders/dequant_q1_0.comp | 29 ++++++++++++++++ .../vulkan-shaders/mul_mm_funcs.glsl | 14 ++++++++ .../src/ggml-vulkan/vulkan-shaders/types.glsl | 16 +++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 7 ++-- 8 files changed, 160 insertions(+), 4 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 20a4d30d5eb..977aff62d81 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3512,6 +3512,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) } #endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q1_0], matmul_q1_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) @@ -3541,6 +3542,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5) } #endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) @@ -3602,6 +3604,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif if (device->coopmat_acc_f16_support) { + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3624,6 +3627,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3658,6 +3662,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); @@ -3721,6 +3726,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -3767,6 +3773,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -3811,6 +3818,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -3884,6 +3892,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -3928,6 +3937,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_subgroup_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -3954,6 +3964,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -4051,6 +4062,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f32_f32", arr_dmmv_q1_0_f32_f32_len[reduc], arr_dmmv_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); @@ -4075,6 +4087,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f16_f32", arr_dmmv_q1_0_f16_f32_len[reduc], arr_dmmv_q1_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); @@ -4125,6 +4138,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q1_0], "mul_mat_vec_id_q1_0_f32", arr_dmmv_id_q1_0_f32_f32_len[reduc], arr_dmmv_id_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); @@ -4179,6 +4193,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q1_0], "dequant_q1_0", dequant_q1_0_len, dequant_q1_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 8, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -4204,6 +4219,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q1_0], "get_rows_q1_0", get_rows_q1_0_len, get_rows_q1_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -4229,6 +4245,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q1_0], "get_rows_q1_0_f32", get_rows_q1_0_f32_len, get_rows_q1_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -4310,6 +4327,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_rte_len, cpy_f32_q1_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); @@ -4317,6 +4335,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); } else { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); @@ -4329,6 +4348,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## rte ## _len, set_rows_q1_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ @@ -4346,6 +4366,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #undef SET_ROWS + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q1_0], "cpy_q1_0_f32", cpy_q1_0_f32_len, cpy_q1_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q1_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); @@ -6022,6 +6043,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type VK_LOG_DEBUG("ggml_vk_get_to_fp16()"); switch (type) { case GGML_TYPE_F32: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6093,6 +6115,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte } switch (src0_type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6158,6 +6181,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6248,6 +6272,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); switch (src0_type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6316,6 +6341,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -7263,6 +7289,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const } if (src->type == GGML_TYPE_F32) { switch (to) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -7277,6 +7304,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const if (to == GGML_TYPE_F32) { switch (src->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15269,6 +15297,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15383,6 +15412,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15415,6 +15445,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15438,6 +15469,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15452,6 +15484,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src1_type == GGML_TYPE_F32) { switch (src0_type) { case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index b8c40eec102..4ffa45485c9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -184,6 +184,31 @@ void quantize(uint dst_idx, uint src_idx) } #endif +#if defined(DATA_A_Q1_0) +void quantize(uint dst_idx, uint src_idx) +{ + float sum_abs = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q1_0; j++) { + sum_abs += abs(data_s[src_idx + j]); + } + + const float d = sum_abs / QUANT_K_Q1_0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q1_0 / 8; ++j) { + data_q[dst_idx].qs[j] = uint8_t(0); + } + + [[unroll]] for (int j = 0; j < QUANT_K_Q1_0; ++j) { + if (data_s[src_idx + j] >= 0.0) { + data_q[dst_idx].qs[j / 8] |= uint8_t(1 << (j % 8)); + } + } +} +#endif + #if defined(DATA_A_IQ4_NL) uint best_index(float x) { if (x <= kvalues_iq4nl[0]) return 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 7865a6bda79..ede1275cfc2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -87,6 +87,23 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_Q1_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8u]) >> (iqs % 8u); + return vec2( + (bits & 1u) != 0u ? 1.0f : -1.0f, + (bits & 2u) != 0u ? 1.0f : -1.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8u]) >> (iqs % 8u); + return vec4( + (bits & 1u) != 0u ? 1.0f : -1.0f, + (bits & 2u) != 0u ? 1.0f : -1.0f, + (bits & 4u) != 0u ? 1.0f : -1.0f, + (bits & 8u) != 0u ? 1.0f : -1.0f); +} +#endif + #if defined(DATA_A_IQ1_S) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint ib32 = iqs / 32; @@ -454,6 +471,13 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif +#if defined(DATA_A_Q1_0) +vec2 get_dm(uint ib, uint a_offset) { + const float d = float(data_a[a_offset + ib].d); + return vec2(d, 0); +} +#endif + #if defined(DATA_A_MXFP4) vec2 get_dm(uint ib, uint a_offset) { return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 8ac6482dc94..03035f28120 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -13,6 +13,18 @@ float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], return vf16[idx]; } +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ1_0 { + block_q1_0 block; +}; + +float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint bit = (uint(bl.block.qs[(idx & 0x78) >> 3]) >> (idx & 0x7)) & 1u; + return bit != 0u ? d : -d; +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; }; @@ -685,7 +697,9 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords } #endif -#if defined(DATA_A_Q4_0) +#if defined(DATA_A_Q1_0) +#define dequantFuncA dequantFuncQ1_0 +#elif defined(DATA_A_Q4_0) #define dequantFuncA dequantFuncQ4_0 #elif defined(DATA_A_Q4_1) #define dequantFuncA dequantFuncQ4_1 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp new file mode 100644 index 00000000000..ca0bdbc63e0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q1_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid / 4; + const uint ir = tid % 4; + const uint ib = 4*i + ir; + if (ib >= p.nel / 128) { + return; + } + + const uint b_idx = 512*i + 128*ir + 8*il; + + const float d = float(data_a[ib].d); + const uint bits = uint(data_a[ib].qs[il]); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l] = D_TYPE((bits & (1u << l)) != 0u ? d : -d); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 9b769bbc887..219bd608035 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -130,6 +130,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); +#elif defined(DATA_A_Q1_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 16; + const uint iqs = idx & 0xfu; + + const float d = float(data_a[ib].d); + const uint bits = uint(data_a[ib].qs[iqs]); + + buf_a[buf_idx ] = FLOAT_TYPEV2((bits & 0x01u) != 0u ? d : -d, (bits & 0x02u) != 0u ? d : -d); + buf_a[buf_idx + 1] = FLOAT_TYPEV2((bits & 0x04u) != 0u ? d : -d, (bits & 0x08u) != 0u ? d : -d); + buf_a[buf_idx + 2] = FLOAT_TYPEV2((bits & 0x10u) != 0u ? d : -d, (bits & 0x20u) != 0u ? d : -d); + buf_a[buf_idx + 3] = FLOAT_TYPEV2((bits & 0x40u) != 0u ? d : -d, (bits & 0x80u) != 0u ? d : -d); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index bdb2c09259b..4239070af5e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -188,6 +188,22 @@ struct block_q8_0_packed16 #define DATA_A_QUANT_LEGACY #endif +#define QUANT_K_Q1_0 128 +#define QUANT_R_Q1_0 1 + +struct block_q1_0 +{ + float16_t d; + uint8_t qs[QUANT_K_Q1_0 / 8]; +}; + +#if defined(DATA_A_Q1_0) +#define QUANT_K QUANT_K_Q1_0 +#define QUANT_R QUANT_R_Q1_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q1_0 +#endif + #define QUANT_K_Q8_1 32 #define QUANT_R_Q8_1 1 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 11385f93378..77a55ea812b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -45,6 +45,7 @@ std::string target_cpp = ""; const std::vector type_names = { "f32", "f16", + "q1_0", "q4_0", "q4_1", "q5_0", @@ -553,7 +554,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) + if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4")) load_vec_quant = "4"; @@ -758,13 +759,13 @@ void process_shaders() { string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}); - for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + for (std::string t : {"q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } - for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); From 3fc738a8c2c798deef3371c4a5da95aaa251379c Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Fri, 10 Apr 2026 13:52:01 -0400 Subject: [PATCH 114/249] ggml-webgpu: address quantization precision and backend lifecycle managment (llama/21521) * ggml(webgpu): fix the busy-polls in Emscripten in the waitAny after #20618, and remove the busy webgpu log * Merge with upstream * Fix GET_ROWS packed integer NaN when using f16 as memory buffer in shader quants * Update Unary wgsl EXP and EXPM1 for f16 stability * Fix GET_ROWS IQ4_XS strcut for NaN f16 canonicalization * Fix numerical percision for unary sqrt when working with f16 * Fix NaN canonicalization for packed integers using f16 * Update err threshold for binary div ops when using f16 * backend: Keep one Dawn/WebGPU instance alive for the lifetime of the static backend * clean: uncomment existing code logs * clean: clean the unncessary debug info * Refactor and generalize dequant helpers * Remove deprecated quant structs * Refactor shader defines to reduce repetition * Remove error override for F16 type * fix: fix the accidential removal of the proper initialization of ctx * clean: clean legacy and format code * fix: did not modify tests ops --------- Co-authored-by: Jeremy J. Hartmann --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 55 ++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 37 +++- .../wgsl-shaders/common_decls.tmpl | 139 +++---------- .../ggml-webgpu/wgsl-shaders/get_rows.wgsl | 189 +++++++++++------- .../src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl | 161 ++++++++------- .../wgsl-shaders/mul_mat_decls.tmpl | 78 ++++---- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 46 ++--- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 8 +- 8 files changed, 383 insertions(+), 330 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c10157766d9..3de6258c74d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1115,6 +1115,32 @@ class ggml_webgpu_shader_lib { std::string type_upper = type_str; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + switch (key.src_type) + { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: + { + // Quantized types using u32 buffers for portability. + defines.push_back("SRC_TYPE=u32"); + defines.push_back("U32_DEQUANT_HELPERS"); + break; + } + default: + { + defines.push_back(std::string("SRC_TYPE=") + type_str); + } + } + defines.push_back("BYTE_HELPERS"); defines.push_back(type_upper + "_T"); defines.push_back(type_upper); @@ -1125,7 +1151,6 @@ class ggml_webgpu_shader_lib { variant += "_"; variant += type_str; - defines.push_back(std::string("SRC_TYPE=") + type_str); defines.push_back("DST_TYPE=f32"); if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || @@ -1593,11 +1618,35 @@ class ggml_webgpu_shader_lib { break; default: { - // quantized types std::string type_upper = src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - defines.push_back(std::string("SRC0_TYPE=") + src0_name); + switch (context.src0->type) + { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: + { + // Quantized types using u32 buffers for portability. + defines.push_back("SRC0_TYPE=u32"); + defines.push_back("U32_DEQUANT_HELPERS"); + break; + } + default: + { + defines.push_back(std::string("SRC0_TYPE=") + src0_name); + } + } + defines.push_back("BYTE_HELPERS"); defines.push_back(type_upper + "_T"); defines.push_back(type_upper); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index edfc6579171..3b894a9b9cc 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -97,6 +97,14 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim /* End Constants */ +static inline wgpu::CallbackMode ggml_webgpu_callback_mode() { +#ifdef __EMSCRIPTEN__ + return wgpu::CallbackMode::AllowProcessEvents; +#else + return wgpu::CallbackMode::AllowSpontaneous; +#endif +} + // This is a "fake" base pointer, since WebGPU buffers do not have pointers to // their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT @@ -474,7 +482,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) { const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, + ggml_webgpu_callback_mode(), [&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { callback_status = status; callback_message = std::string(message); @@ -494,7 +502,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, std::string callback_message; const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( - buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, + buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(), [&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) { callback_status = status; callback_message = std::string(message); @@ -526,7 +534,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); ctx->queue.Submit(1, &commands); - ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize()); + if (!ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, + ctx->debug_host_buf.GetSize())) { + GGML_LOG_ERROR("ggml_webgpu: Debug buffer map failed\n"); + return; + } const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange(); std::cout << "debug[0]: " << debug_data[0] << "\n"; ctx->debug_host_buf.Unmap(); @@ -542,7 +554,7 @@ static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & auto ts_bufs = command.timestamp_query_bufs; wgpu::Future f = ts_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, + wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(), [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { if (status != wgpu::MapAsyncStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); @@ -3420,7 +3432,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->instance.WaitAny( ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, + &options, ggml_webgpu_callback_mode(), [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { if (status != wgpu::RequestAdapterStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); @@ -3491,8 +3503,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { dev_desc.requiredFeatures = required_features.data(); dev_desc.requiredFeatureCount = required_features.size(); dev_desc.SetDeviceLostCallback( - wgpu::CallbackMode::AllowSpontaneous, - [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + ggml_webgpu_callback_mode(), + [ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { if (reason == wgpu::DeviceLostReason::Destroyed) { return; } @@ -3525,7 +3537,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->instance.WaitAny( ctx->webgpu_global_ctx->adapter.RequestDevice( - &dev_desc, wgpu::CallbackMode::AllowSpontaneous, + &dev_desc, ggml_webgpu_callback_mode(), [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { if (status != wgpu::RequestDeviceStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str()); @@ -4046,6 +4058,13 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 0; + // Keep one Dawn/WebGPU instance alive for the lifetime of the static backend + // registry. Recreating it on repeated registry lookups can invalidate + // adapter/device references that are still held by the backend/device layer. + if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) { + return ® + } + wgpu::InstanceDescriptor instance_descriptor{}; std::vector instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; instance_descriptor.requiredFeatures = instance_features.data(); @@ -4063,11 +4082,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); ctx.webgpu_global_ctx->instance = std::move(inst); + // Probe for adapter support wgpu::Adapter adapter; if (ctx.webgpu_global_ctx->instance != nullptr) { wgpu::RequestAdapterOptions options = {}; - // probe for adapter support ctx.webgpu_global_ctx->instance.WaitAny( ctx.webgpu_global_ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index feb0bca3f84..0d3501c34a2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -9,36 +9,44 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { #endif #ifdef U32_DEQUANT_HELPERS -fn load_src0_u16_at(byte_offset: u32) -> u32 { - let word = src0[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; +fn load_u16_at( + buf: ptr, read_write>, + byte_offset: u32) -> u32 { + let word = buf[byte_offset / 4]; + let shift = (byte_offset & 0x2) * 8; + return (word >> shift) & 0xFFFF; } -fn load_src0_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = src0[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = src0[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); +fn load_u32_at( + buf: ptr, read_write>, + byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4; + let shift = (byte_offset & 0x3) * 8; + let lo = buf[word_idx]; + let hi = buf[word_idx + 1]; + let shifted = (lo >> shift) | (hi << (32 - shift)); + return select(shifted, lo, shift == 0); } -fn load_src0_f16_at(byte_offset: u32) -> f16 { - let packed = unpack2x16float(load_src0_u16_at(byte_offset)); +fn load_f16_at( + buf: ptr, read_write>, + byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_u16_at(buf, byte_offset)); return f16(packed[0]); } -#endif -#ifdef Q4_0_T -struct q4_0 { - d: f16, - qs: array -}; +fn load_f16_as_f32_at( + buf: ptr, read_write>, + byte_offset: u32) -> f32 { + let word = buf[byte_offset / 4]; + let shift = (byte_offset & 0x2) * 8; + let d_bits = (word >> shift) & 0xFFFF; + return unpack2x16float(d_bits)[0]; +} #endif + + #ifdef Q4_1_T struct q4_1 { d: f16, @@ -47,13 +55,6 @@ struct q4_1 { }; #endif -#ifdef Q5_0_T -struct q5_0 { - d: f16, - qh: array, - qs: array -}; -#endif #ifdef Q5_1_T struct q5_1 { @@ -64,12 +65,6 @@ struct q5_1 { }; #endif -#ifdef Q8_0_T -struct q8_0 { - d: f16, - qs: array -}; -#endif #ifdef Q8_1_T struct q8_1 { @@ -88,14 +83,6 @@ struct q2_K { }; #endif -#ifdef Q3_K_T -struct q3_K { - hmask: array, - qs: array, - scales: array, - d: f16 -}; -#endif #if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN) fn get_scale_min(is: u32, scales: array) -> vec2 { @@ -132,64 +119,6 @@ struct q5_K { }; #endif -#ifdef Q6_K_T -struct q6_K { - ql: array, - qh: array, - scales: array, - d: f16 -}; -#endif - -#ifdef IQ2_XXS_T -struct iq2_xxs { - d: f16, - qs: array -}; -#endif - -#ifdef IQ2_XS_T -struct iq2_xs { - d: f16, - qs: array, - scales: array -}; -#endif - -#ifdef IQ2_S_T -struct iq2_s { - d: f16, - qs: array, - qh: array, - scales: array -}; -#endif - -#ifdef IQ3_XXS_T -struct iq3_xxs { - d: f16, - qs: array -}; -#endif - -#ifdef IQ3_S_T -struct iq3_s { - d: f16, - qs: array, - qh: array, - signs: array, - scales: array -}; -#endif - -#ifdef IQ1_S_T -struct iq1_s { - d: f16, - qs: array, - qh: array -}; -#endif - #ifdef IQ1_M_T struct iq1_m { qs: array, @@ -198,17 +127,9 @@ struct iq1_m { }; #endif -#ifdef IQ4_NL_T -struct iq4_nl { - d: f16, - qs: array, -}; -#endif - #ifdef IQ4_XS_T struct iq4_xs { - d: f16, - scales_h: f16, + d_scales_h: u32, scales_l: u32, qs: array }; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index d9eb6a3567e..3c8b84c9ac3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -27,17 +27,18 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q4_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q4_0 = src[src_base + offset]; - let d = f32(block_q4_0.d); - for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); + for (var j: u32 = 0u; j < 4; j++) { + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at(&src, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; let dst_offset = dst_base + offset * 32 + j * 4 + k; dst[dst_offset] = q_lo; - dst[dst_offset + 16] = q_hi; + dst[dst_offset + 16u] = q_hi; } } } @@ -64,17 +65,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q5_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q5_0 = src[src_base + offset]; - let d = f32(block_q5_0.d); - let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); + let qh_packed = load_u32_at(&src, block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 6 + j * 4; + let q_packed = load_u32_at(&src, q_byte_offset); + for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); + let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; dst[dst_offset] = q_lo; dst[dst_offset + 16] = q_hi; @@ -106,14 +112,15 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q8_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q8_0 = src[src_base + offset]; - let d = f32(block_q8_0.d); - for (var j: u32 = 0; j < 8; j++) { - let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { + let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); + for (var j: u32 = 0u; j < 8u; j++) { + let q_byte_offset = block_byte_base + 2u + j * 4u; + let q_packed = load_u32_at(&src, q_byte_offset); + for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; - let dst_offset = dst_base + offset * 32 + j * 4 + k; + let dst_offset = dst_base + offset * 32u + j * 4u + k; dst[dst_offset] = q_val; } } @@ -152,36 +159,42 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q3_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes - // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, - // and 2-bits from the last 4 bytes + // Bytes 108-109: f16 scale 'd' + let d = load_f16_as_f32_at(&src, block_byte_base + 108); + + // Bytes 96-107: 12 bytes of scales (3 u32s) let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; + var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } + scale_vals[0] = load_u32_at(&src, block_byte_base + 96); + scale_vals[1] = load_u32_at(&src, block_byte_base + 100); + scale_vals[2] = load_u32_at(&src, block_byte_base + 104); + var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - // convert arrays of f16 -> u32 + // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4); } + + // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + for (var i: u32 = 0u; i < 16; i++) { + qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4); } var dst_i = dst_base + offset * 256; var is: u32 = 0; var m: u32 = 1; + // 2 halves of the block (128 elements each) for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { // 4 groups (each group has 2 blocks of 16 elements) @@ -191,11 +204,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let sc = get_byte(scale_vals[is / 4], is % 4); is++; let dl = d * (f32(sc) - 32.0); - for (var l: u32 = 0u; l < 16u; l++) { + + for (var l: u32 = 0; l < 16; l++) { let q_idx = q_b_idx + k + l; let hm_idx = k + l; let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); let qs_val = (q_byte >> shift) & 3; dst[dst_i] = (f32(qs_val) - hm) * dl; @@ -268,21 +283,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q6_K // 16 blocks of 16 elements each fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes - // convert arrays of f16 -> u32 + // Bytes 208-209: f16 scale 'd' + let d = load_f16_as_f32_at(&src, block_byte_base + 208); + + // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4); } + + // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + for (var i: u32 = 0; i < 16u; i++) { + qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u); } + + // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4); } var dst_i = dst_base + offset * 256; @@ -323,12 +344,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); - let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let aux0_offset = block_byte_base + 2 + ib * 2; + let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; + let aux0 = load_u32_at(&src, aux0_offset); + let aux1 = load_u32_at(&src, aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -345,15 +368,19 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } #endif + + #ifdef IQ2_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; + var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) + load_u32_at(&src, block_byte_base + 66), + load_u32_at(&src, block_byte_base + 70) ); + for (var ib: u32 = 0; ib < 32; ib += 4) { let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); let db = array( @@ -361,7 +388,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { d * (0.5 + f32(s >> 4)) * 0.25 ); for (var l: u32 = 0; l < 4; l++) { - let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let qs_offset = block_byte_base + 2 + (ib + l) * 2; + let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -379,21 +407,23 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; + var qs_vals : array; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4); } - var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) - ); - var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) - ); + + var qh_vals: array; + qh_vals[0] = load_u32_at(&src, block_byte_base + 66); + qh_vals[1] = load_u32_at(&src, block_byte_base + 70); + + var scale_vals: array; + scale_vals[0] = load_u32_at(&src, block_byte_base + 74); + scale_vals[1] = load_u32_at(&src, block_byte_base + 78); + for (var ib: u32 = 0; ib < 8; ib ++) { let s = get_byte(scale_vals[ib / 4], ib % 4); let db = array( @@ -419,16 +449,17 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ3_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; + let sc_sign = load_u32_at(&src, sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -448,18 +479,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ3_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; + var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) + load_u32_at(&src, block_byte_base + 66), + load_u32_at(&src, block_byte_base + 70) ); + var sign_vals: array; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4); } - var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + + var scale_vals = load_u32_at(&src, block_byte_base + 106); + for (var ib: u32 = 0; ib < 4; ib++) { let s = get_byte(scale_vals, ib); let db = array( @@ -472,7 +507,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -493,14 +528,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ1_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = bitcast(vec2(block.qh[ib], 0.0)); - let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF; + let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -560,12 +595,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ4_NL fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 32; var qs: array; for (var i: u32 = 0; i < 4; i++) { - qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); @@ -579,8 +614,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ4_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; - let d = f32(block.d); - let scales_h = bitcast(vec2(block.scales_h, 0.0)); + let d = unpack2x16float(block.d_scales_h)[0]; + let scales_h = block.d_scales_h >> 16; var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 8; ib++) { let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index 5b9f5b36224..fdabaf09b2e 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -20,11 +20,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q4_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q4_0 = src0[src0_idx_base + offset]; - let d = f32(block_q4_0.d); + let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var sum: f32 = 0.0; for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; @@ -61,12 +62,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q5_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q5_0 = src0[src0_idx_base + offset]; - let d = f32(block_q5_0.d); + let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var sum: f32 = 0.0; - let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + let qh_packed = load_u32_at(&src0, block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 6 + j * 4; + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; @@ -107,12 +109,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q8_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q8_0 = src0[src0_idx_base + offset]; - let d = f32(block_q8_0.d); + let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var sum: f32 = 0.0; for (var j: u32 = 0; j < 8; j++) { - let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at(&src0, q_byte_offset); + for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; @@ -178,31 +181,37 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q3_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes + + // Bytes 108-109: f16 scale 'd' + let d = load_f16_as_f32_at(&src0, block_byte_base + 108); // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, // and 2-bits from the last 4 bytes + // Bytes 96-107: 12 bytes of scales (3 u32s) let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } + scale_vals[0] = load_u32_at(&src0, block_byte_base + 96); + scale_vals[1] = load_u32_at(&src0, block_byte_base + 100); + scale_vals[2] = load_u32_at(&src0, block_byte_base + 104); + var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - // convert arrays of f16 -> u32 + // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4); } + + // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + for (var i: u32 = 0u; i < 16; i++) { + qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4); } var sum = 0.0; @@ -301,21 +310,27 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q6_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes - // convert arrays of f16 -> u32 + // Bytes 208-209: f16 scale 'd' + let d = load_f16_as_f32_at(&src0, block_byte_base + 208); + + // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4); } + + // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array; for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4); } + + // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4); } var sum = 0.0; @@ -358,13 +373,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); - let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let aux0_offset = block_byte_base + 2 + ib * 2; + let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; + let aux0 = load_u32_at(&src0, aux0_offset); + let aux1 = load_u32_at(&src0, aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -384,13 +401,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; + var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) + load_u32_at(&src0, block_byte_base + 66), + load_u32_at(&src0, block_byte_base + 70) ); + var sum = 0.0; for (var ib: u32 = 0; ib < 32; ib += 4) { let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); @@ -399,7 +418,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { d * (0.5 + f32(s >> 4)) * 0.25 ); for (var l: u32 = 0; l < 4; l++) { - let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let qs_offset = block_byte_base + 2 + (ib + l) * 2; + let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -418,21 +438,23 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; + var qs_vals : array; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4); } - var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) - ); - var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) - ); + + var qh_vals: array; + qh_vals[0] = load_u32_at(&src0, block_byte_base + 66); + qh_vals[1] = load_u32_at(&src0, block_byte_base + 70); + + var scale_vals: array; + scale_vals[0] = load_u32_at(&src0, block_byte_base + 74); + scale_vals[1] = load_u32_at(&src0, block_byte_base + 78); + var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib ++) { let s = get_byte(scale_vals[ib / 4], ib % 4); @@ -460,17 +482,18 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ3_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; + let sc_sign = load_u32_at(&src0, sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -491,18 +514,22 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ3_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; + var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) + load_u32_at(&src0, block_byte_base + 66), + load_u32_at(&src0, block_byte_base + 70) ); + var sign_vals: array; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4); } - var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + + var scale_vals = load_u32_at(&src0, block_byte_base + 106); + var sum = 0.0; for (var ib: u32 = 0; ib < 4; ib++) { let s = get_byte(scale_vals, ib); @@ -516,7 +543,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -538,15 +565,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ1_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = bitcast(vec2(block.qh[ib], 0.0)); - let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF; + let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -610,13 +637,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ4_NL fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 32; var sum = 0.0; var qs: array; for (var i: u32 = 0; i < 4; i++) { - qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); @@ -631,8 +658,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ4_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let scales_h = bitcast(vec2(block.scales_h, 0.0)); + let d = unpack2x16float(block.d_scales_h)[0]; + let scales_h = block.d_scales_h >> 16; var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib++) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index ea91c13468f..374137ff8e8 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); + let d = load_f16_at(&src0, block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let m = load_src0_f16_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let m = load_f16_at(&src0, block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_lo = f16(q_byte & 0xF) * d + m; @@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let qh_packed = load_src0_u32_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let qh_packed = load_u32_at(&src0, block_byte_base + 2u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let m = load_src0_f16_at(block_byte_base + 2u); - let qh_packed = load_src0_u32_at(block_byte_base + 4u); + let d = load_f16_at(&src0, block_byte_base); + let m = load_f16_at(&src0, block_byte_base + 2u); + let qh_packed = load_u32_at(&src0, block_byte_base + 4u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); + let d = load_f16_at(&src0, block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j+=2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let m = load_src0_f16_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let m = load_f16_at(&src0, block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j+=2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base + 80u); - let dmin = load_src0_f16_at(block_byte_base + 82u); + let d = load_f16_at(&src0, block_byte_base + 80u); + let dmin = load_f16_at(&src0, block_byte_base + 82u); // Decode the element at position k_in_block let block_of_32 = k_in_block / 32u; @@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let is = k_in_block / 16u; - let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u)); + let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u)); let sc = get_byte(sc_packed, is % 4u); let dl = d * f16(sc & 0xFu); let ml = dmin * f16(sc >> 4u); let q_idx = q_b_idx + k + l; - let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 3u; @@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base + 108u); + let d = load_f16_at(&src0, block_byte_base + 108u); // Load and unpack scales let kmask1: u32 = 0x03030303u; @@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var scale_vals: array; for (var i: u32 = 0u; i < 4u; i++) { - scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i); + scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i); } var tmp: u32 = scale_vals[2]; @@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load hmask and qs arrays var hmask_vals: array; for (var i: u32 = 0u; i < 8u; i++) { - hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i); + hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i); } var qs_vals: array; for (var i: u32 = 0u; i < 16u; i++) { - qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i); + qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i); } let half = k_in_block / 128u; // 0 or 1 @@ -499,13 +499,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let dmin = load_src0_f16_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let dmin = load_f16_at(&src0, block_byte_base + 2u); // Load packed scales var scale_vals: array; for (var i: u32 = 0u; i < 3u; i++) { - scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i); + scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i); } // Map k_in_block to loop structure: @@ -541,7 +541,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 0xFu; @@ -575,13 +575,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let dmin = load_src0_f16_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let dmin = load_f16_at(&src0, block_byte_base + 2u); // Load packed scales var scale_vals: array; for (var i: u32 = 0u; i < 3u; i++) { - scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i); + scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i); } // The original loop processes elements in groups of 64 @@ -621,11 +621,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); - let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u)); + let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u)); let qh_byte = get_byte(qh_packed, l % 4u); @@ -673,17 +673,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only ql13 word needed let ql13_flat = ql_b_idx + l; - let ql13 = load_src0_u32_at(block_byte_base + ql13_flat); + let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat); let ql13_b = get_byte(ql13, 0u); // Load only ql24 word needed let ql24_flat = ql_b_idx + l + 32u; - let ql24 = load_src0_u32_at(block_byte_base + ql24_flat); + let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat); let ql24_b = get_byte(ql24, 0u); // Load only qh word needed let qh_flat = qh_b_idx + l; - let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat); + let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat); let qh_b = get_byte(qh, 0u); let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); @@ -694,10 +694,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only the scale word needed let is = l / 16u; let sc_idx = sc_b_idx + is + quarter * 2u; - let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx); + let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx); let sc_val = get_byte_i32(sc, 0u); - let d = load_src0_f16_at(block_byte_base + 208u); + let d = load_f16_at(&src0, block_byte_base + 208u); var q_val: f16; if (quarter == 0u) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 6525f23bdfc..6f6bcaf7940 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -65,10 +65,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); + let d = f32(load_f16_at(&src0, block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; @@ -98,11 +98,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let m = f32(load_src0_f16_at(block_byte_base + 2u)); + let d = f32(load_f16_at(&src0, block_byte_base)); + let m = f32(load_f16_at(&src0, block_byte_base + 2u)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = f32((q_byte >> 4) & 0xF) * d + m; @@ -132,12 +132,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let qh_packed = load_src0_u32_at(block_byte_base + 2u); + let d = f32(load_f16_at(&src0, block_byte_base)); + let qh_packed = load_u32_at(&src0, block_byte_base + 2u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -176,13 +176,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let m = load_src0_f16_at(block_byte_base + 2u); - let qh_packed = load_src0_u32_at(block_byte_base + 4u); + let d = f32(load_f16_at(&src0, block_byte_base)); + let m = load_f16_at(&src0, block_byte_base + 2u); + let qh_packed = load_u32_at(&src0, block_byte_base + 4u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -221,11 +221,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); + let d = f32(load_f16_at(&src0, block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -254,12 +254,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let m = load_src0_f16_at(block_byte_base + 2u); + let d = f32(load_f16_at(&src0, block_byte_base)); + let m = load_f16_at(&src0, block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d + f32(m); @@ -309,13 +309,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = ix; i < nb; i += 2u) { let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES; - let d = f32(load_src0_f16_at(bbase + 208u)); + let d = f32(load_f16_at(&src0, bbase + 208u)); - let ql1_u32 = load_src0_u32_at(bbase + q_offset_l); - let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u); - let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h); - let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte); - let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u); + let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l); + let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u); + let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h); + let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte); + let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u); let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index 21beb9bb94d..8c334817ccd 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -107,7 +107,8 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx])); #endif #ifdef EXP - let res = exp(src[params.offset_src + src_idx]); + let src_f32 = f32(src[params.offset_src + src_idx]); + let res = TYPE(exp(src_f32)); #endif #ifdef LOG let res = TYPE(log(f32(src[params.offset_src + src_idx]))); @@ -161,7 +162,8 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0)); #endif #ifdef EXPM1 - let res = exp(src[params.offset_src + src_idx]) - 1.0; + let src_f32 = f32(src[params.offset_src + src_idx]); + let res = TYPE(exp(src_f32) - 1.0); #endif #ifdef FLOOR let res = floor(src[params.offset_src + src_idx]); @@ -181,7 +183,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx]; #endif #ifdef SQRT - let res = sqrt(src[params.offset_src + src_idx]); + let res = TYPE(sqrt(f32(src[params.offset_src + src_idx]))); #endif #ifdef SIN let res_f32 = sin(f32(src[params.offset_src + src_idx])); From 2580cfc70360cddcc09de271c84c90c57771e30c Mon Sep 17 00:00:00 2001 From: Rithik Sharma Date: Fri, 10 Apr 2026 10:52:38 -0700 Subject: [PATCH 115/249] ggml-webgpu: support non-square subgroup matrix configs for Intel GPUs (llama/21669) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 13 +++++-- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 34 +++++++++---------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3b894a9b9cc..e979783f020 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3461,13 +3461,15 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); #ifndef __EMSCRIPTEN__ - // Only support square f16 matrices of size 8 or 16 for now + // Accept f16 subgroup matrix configurations (square or non-square). + // NVIDIA GPUs typically report square configs (e.g. 16x16x16), + // while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16). + // The shaders are already parameterized to handle any M/N/K dimensions. bool valid_subgroup_matrix_config = false; if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; - if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && - config.componentType == wgpu::SubgroupMatrixComponentType::F16 && + if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 && config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M; ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N; @@ -3805,6 +3807,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { break; } + // Head dimensions must be divisible by subgroup matrix dimensions + if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 || + src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) { + break; + } // Head dimensions must fit in workgroup memory with minimum tile sizes size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; const bool has_mask = op->src[3] != nullptr; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 8b76cecba91..aa2d2e54db9 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -369,35 +369,35 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #endif for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) { let inter_offset = kv_block * SG_MAT_N; - var acc: subgroup_matrix_result = subgroupMatrixLoad>(&inter_shmem, inter_offset, false, KV_TILE); + var acc: subgroup_matrix_result = subgroupMatrixLoad>(&inter_shmem, inter_offset, false, KV_TILE); - var q_cur = subgroupMatrixLoad>(&q_shmem, 0u, false, HEAD_DIM_QK); + var q_cur = subgroupMatrixLoad>(&q_shmem, 0u, false, HEAD_DIM_QK); #ifdef KV_DIRECT - var k_cur = subgroupMatrixLoad>(&K, k_global_offset + 0u, true, params.stride_k1); + var k_cur = subgroupMatrixLoad>(&K, k_global_offset + 0u, true, params.stride_k1); #else - var k_cur = subgroupMatrixLoad>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK); + var k_cur = subgroupMatrixLoad>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK); #endif var t: u32 = 1u; for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) { let h0 = t * SG_MAT_K; - var q0 = subgroupMatrixLoad>(&q_shmem, h0, false, HEAD_DIM_QK); + var q0 = subgroupMatrixLoad>(&q_shmem, h0, false, HEAD_DIM_QK); #ifdef KV_DIRECT - var k0 = subgroupMatrixLoad>(&K, k_global_offset + h0, true, params.stride_k1); + var k0 = subgroupMatrixLoad>(&K, k_global_offset + h0, true, params.stride_k1); #else - var k0 = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK); + var k0 = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK); #endif acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); q_cur = q0; k_cur = k0; let h1 = (t + 1u) * SG_MAT_K; - var q1g = subgroupMatrixLoad>(&q_shmem, h1, false, HEAD_DIM_QK); + var q1g = subgroupMatrixLoad>(&q_shmem, h1, false, HEAD_DIM_QK); #ifdef KV_DIRECT - var k1g = subgroupMatrixLoad>(&K, k_global_offset + h1, true, params.stride_k1); + var k1g = subgroupMatrixLoad>(&K, k_global_offset + h1, true, params.stride_k1); #else - var k1g = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK); + var k1g = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK); #endif acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); q_cur = q1g; @@ -407,11 +407,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // handle odd tail if (t < HEAD_DIM_QK / SG_MAT_K) { let h = t * SG_MAT_K; - var qn = subgroupMatrixLoad>(&q_shmem, h, false, HEAD_DIM_QK); + var qn = subgroupMatrixLoad>(&q_shmem, h, false, HEAD_DIM_QK); #ifdef KV_DIRECT - var kn = subgroupMatrixLoad>(&K, k_global_offset + h, true, params.stride_k1); + var kn = subgroupMatrixLoad>(&K, k_global_offset + h, true, params.stride_k1); #else - var kn = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK); + var kn = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK); #endif acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); q_cur = qn; @@ -566,7 +566,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, head_dim_block < HEAD_DIM_V; head_dim_block += num_subgroups * SG_MAT_N) { // load O submatrix from shared memory - var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( + var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( &o_shmem, head_dim_block, false, @@ -574,7 +574,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, ); for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { let p_offset = kv_block * SG_MAT_N; - var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( &inter_shmem, p_offset, false, @@ -585,7 +585,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #ifdef KV_DIRECT let v_block_row = kv_tile + kv_block * SG_MAT_N; let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block; - var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &V, v_global_offset, false, @@ -593,7 +593,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, ); #else let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V; - var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &kv_shmem, v_block_offset + head_dim_block, false, From 28ce072f59523b0a3a1752ceab7516e6e5d9a86d Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Fri, 10 Apr 2026 15:47:43 -0700 Subject: [PATCH 116/249] hexagon: improved Op queuing, buffer and cache management (llama/21705) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * hexagon: introduce op request batching and rewrite buffer managment The host now prepares batches of requests and dispatches them via a single dspqueue message. Buffers are mapped explicitly by NPU while processing batches. * hex-dma: disable l2 bypass since to work around new issue due to no flushes between Ops * hex-utils: add explicit l2flush and l2clear helpers * hex-opreq: use fine-grain per tensor l2 management * hex-opreq: avoid redundant invalidates for tensors we already flushed * hex-opreq: update debug messages * htp-opreq: reuse ops_context * hex-opreq: do not flush or invalidate cache lines beyond buffer boundry * hex-opreq: fix errors in log message * Revert "hex-opreq: do not flush or invalidate cache lines beyond buffer boundry" This reverts commit 8b7f0a55a750a6430ce4eb1874c7feb3d720056d. * hexagon: limit l2 flushes to 1MB which covers l2 cache * hex-opreq: limit cache flush to 4MB Looks like 4MB cont. vitual space should cover the 1MB cache. * hexagon: drop cache flush size to 2MB * hex-opreq: start reworking opreq packing * hex-opreq: introduce new way of packing opbatch where tensors are stored separately * hex-opreq: add a simple fastrpc call to force unmap all buffers * hex-l2flush: somehow 2MB does not seem robust, also cleanup step size to use line-size * hex-opreq: bump opreq batch size to 256 * hex-mm: place src1 spad at the top of vtcm for easy reuse * hex-ops: introduce internal types and disable src1 reuse for now Nothing new just formalizing the repack / qyn.quant types we've been using. * htp-opreq: use tensor pointers instead of copies * hex-opreq: introduce more robust way for tracking vtcm/spad reuse This removes the SKIP_QUANTIZE flag that became fragile with the addition of HMX and other ops. * hex-cumsum: fix error post opreq merge * hex-opreq: move request batch handling into the session Prepping everything for using dspqueue buffers and doing that inside the session is much cleaner. * hex-mm: yet another fix for src1 reuse when we're mixing hmx/hvx * hex-bufs: introduce pinned mmapings and use non-pinned ones for model buffers * hex-buf: add support for allocating shared/pinned buffer for opreqs * hex-opbatch: make opbatches configurable * hex-naming: better name for ggml_hexagon_shared_buffer * hex-naming: add session->c_name() helper * hex-opbatch: start using shm but still copy for now * hex-opbatch: use shared buffer for packing opbatch * hex-opbatch: beter naming for opbatch related classes and code * hex-opbatch: reuse batched tensors with same data/dims/strides * hex-opbatch: update logging * hex-opbatch: add support for vmem limit for op batching * hex-opbatch: update htp side to properly support dynamic mmap/unmap * hex-opbatch: add OB and OQ params for run-completion script and fix the asserts in batch processing * hex-opbatch: fixed src1 handling in act ops * hex-act: fix empty src1 handling in swiglu and friends Simplify preamble macro while at it * hex-mm: minor fix vtcm and dma handling in matmul cleaning up some left-overs from merges * hex-opbatch: allocate extra 1KB for dspqueue overhead * hexagon: fix softmax for non-aligned tensors and cleanup vtcm alloc * hex-mm: properly handle hmx_disabled flag * hex-ops: update comments * hex-ops: add debug output for get/set-rows * hex-mmap: optimize un/mapping of buffers * hex-opreq: global cache flush and invalidate beyond 128KB threshold * hex-ops: add super simple opfilter regex for debugging If an Op matches the regex hex backend will reject it. * hex-opbatch: wireup newer ops missed in merge and update main switch to detect this in future * hexagon: improved vtcm acquision to remove inter-op overhead Fully compatible with QNN-HTP coex * hex-mm: fixed hvx fallback path * hex-mm: lower the vmem threshold a bit further to ~3GB * hexagon: update debug & error logs This also fixes an issue with newer llvm merging repack and non-repack functions. We use those pointer to distinguish between buffer types. * hexagon: move ops context into main context Just a cleanup. We don't need separate contexts at this point. * hex-opbatch: cleanup naming and headers for opbatch and related descriptors * hex-fa: it's now better to enable FA during TG to reduce graph splits * hexagon: remove GGML_HEXAGON_EXPERIMENTAL env var It's no longer useful. Please use more flexible GGML_HEXAGON_OPFILTER to disable Ops if needed for debugging or validation. * hexagon: fixed editorconfig check * Update ggml/src/ggml-hexagon/ggml-hexagon.cpp Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Trivikram Reddy Co-authored-by: Sigbjørn Skjæret --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 1343 +++++++++--------- ggml/src/ggml-hexagon/htp/act-ops.c | 137 +- ggml/src/ggml-hexagon/htp/argsort-ops.c | 18 +- ggml/src/ggml-hexagon/htp/binary-ops.c | 46 +- ggml/src/ggml-hexagon/htp/cpy-ops.c | 10 +- ggml/src/ggml-hexagon/htp/cumsum-ops.c | 25 +- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 36 +- ggml/src/ggml-hexagon/htp/get-rows-ops.c | 74 +- ggml/src/ggml-hexagon/htp/hex-utils.h | 21 + ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 25 +- ggml/src/ggml-hexagon/htp/hmx-ops.h | 6 +- ggml/src/ggml-hexagon/htp/htp-ctx.h | 106 +- ggml/src/ggml-hexagon/htp/htp-msg.h | 166 --- ggml/src/ggml-hexagon/htp/htp-ops.h | 183 ++- ggml/src/ggml-hexagon/htp/htp_iface.idl | 2 + ggml/src/ggml-hexagon/htp/main.c | 1418 +++++--------------- ggml/src/ggml-hexagon/htp/matmul-ops.c | 229 +++- ggml/src/ggml-hexagon/htp/repeat-ops.c | 10 +- ggml/src/ggml-hexagon/htp/rope-ops.c | 31 +- ggml/src/ggml-hexagon/htp/set-rows-ops.c | 90 +- ggml/src/ggml-hexagon/htp/softmax-ops.c | 252 ++-- ggml/src/ggml-hexagon/htp/ssm-conv.c | 21 +- ggml/src/ggml-hexagon/htp/sum-rows-ops.c | 12 +- ggml/src/ggml-hexagon/htp/unary-ops.c | 12 +- 24 files changed, 1732 insertions(+), 2541 deletions(-) delete mode 100644 ggml/src/ggml-hexagon/htp/htp-msg.h diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index ac5baa2acaf..3d68b80048f 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -7,10 +7,14 @@ #include #include -#include #include +#include +#include #include #include +#include +#include +#include #ifdef _WIN32 # include @@ -33,7 +37,7 @@ #include "ggml-impl.h" #include "ggml-quants.h" #include "op-desc.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp_iface.h" #include "htp-drv.h" @@ -44,12 +48,14 @@ static int opt_etm = 0; static int opt_verbose = 0; static int opt_profile = 0; static int opt_hostbuf = 1; // hostbuf ON by default -static int opt_experimental = 0; static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only // Enable all stages by default -static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE; -static int opt_opsync = 0; // synchronous ops +static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_COMPUTE; +static int opt_opsync = 0; // synchronous ops +static int opt_opbatch = 1024; // max number of ops in a batch +static int opt_opqueue = 16; // max number of pending batches +static std::regex* opt_opfilter = NULL; // regex of ops to not claim #define HEX_VERBOSE(...) \ if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__) @@ -86,7 +92,7 @@ static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_t op_desc desc(op); GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(), - ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags); + ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags); } static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) { @@ -94,7 +100,7 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct op_desc desc(op); GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), - ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); + ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); } static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op, @@ -103,25 +109,16 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t op_desc desc(op); GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\n", sess_name.c_str(), - ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, + ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec); } // ** backend sessions -struct ggml_hexagon_session { - ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); - ~ggml_hexagon_session() noexcept(true); - - void allocate(int dev_id) noexcept(false); - void release() noexcept(true); - - void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false); - void flush(); - - ggml_backend_buffer_type buffer_type = {}; - ggml_backend_buffer_type repack_buffer_type = {}; +struct ggml_hexagon_opbatch; +struct ggml_hexagon_opshm; +struct ggml_hexagon_session { std::string name; remote_handle64 handle; dspqueue_t queue; @@ -133,87 +130,28 @@ struct ggml_hexagon_session { bool valid_handle; bool valid_queue; bool valid_iface; - std::atomic op_pending; - uint32_t prof_usecs; - uint32_t prof_cycles; - uint32_t prof_pkts; -}; - -void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) { - // Bump pending flag (cleared in the session::flush once we get the response) - this->op_pending++; // atomic inc - - int err = dspqueue_write(this->queue, - 0, // flags - the framework will autoset this - n_bufs, // number of buffers - bufs, // buffer references - sizeof(req), // Message length - (const uint8_t *) &req, // Message - DSPQUEUE_TIMEOUT // Timeout - ); - - if (err != 0) { - GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err); - } - - if (sync) { - flush(); - } -} - -// Flush HTP response queue i.e wait for all outstanding requests to complete -void ggml_hexagon_session::flush() { - dspqueue_t q = this->queue; - - // Repeatedly read packets from the queue until it's empty. We don't - // necessarily get a separate callback for each packet, and new packets - // may arrive while we're processing the previous one. - - while (this->op_pending) { - struct htp_general_rsp rsp; - uint32_t rsp_size; - uint32_t flags; - struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; - uint32_t n_bufs; + std::atomic op_pending; + ggml_hexagon_opbatch *op_batch; + ggml_hexagon_opshm *op_shm; - // Read response packet from queue - int err = dspqueue_read(q, &flags, - HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references - &n_bufs, // Number of buffer references - bufs, // Buffer references - sizeof(rsp), // Max message length - &rsp_size, // Message length - (uint8_t *) &rsp, // Message - DSPQUEUE_TIMEOUT); // Timeout - - if (err == AEE_EEXPIRED) { - // TODO: might need to bail out if the HTP is stuck on something - continue; - } + ggml_backend_buffer_type buffer_type = {}; + ggml_backend_buffer_type repack_buffer_type = {}; - if (err != 0) { - GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err); - } + ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); + ~ggml_hexagon_session() noexcept(true); - // Basic sanity checks - if (rsp_size != sizeof(rsp)) { - GGML_ABORT("ggml-hex: dspcall : bad response (size)\n"); - } + const char* c_name() const { return name.c_str(); } - if (rsp.status != HTP_STATUS_OK) { - GGML_LOG_ERROR("ggml-hex: dspcall : dsp-rsp: %s\n", status_to_str(rsp.status)); - // TODO: handle errors - } + void allocate(int dev_id) noexcept(false); + void release() noexcept(true); - // TODO: update profiling implementation, currently only works for opt_opsync mode - this->prof_usecs = rsp.prof_usecs; - this->prof_cycles = rsp.prof_cycles; - this->prof_pkts = rsp.prof_pkts; + void enqueue_op(htp_op_code opcode, const ggml_tensor *op); + void flush(bool all = true); - this->op_pending--; // atomic dec - } -} + void flush_pending(bool all = false); + void flush_batch(); +}; // ** backend buffers @@ -227,82 +165,99 @@ struct ggml_backend_hexagon_buffer_type_context { std::string name; }; -struct ggml_backend_hexagon_buffer_context { - bool mmap_to(ggml_hexagon_session * s) { - HEX_VERBOSE("ggml-hex: %s mmaping buffer: base %p domain-id %d session-id %d size %zu fd %d repack %d\n", - s->name.c_str(), (void *) this->base, s->domain_id, s->session_id, this->size, this->fd, - (int) this->repack); +struct ggml_hexagon_shared_buffer { + ggml_hexagon_session * sess; + uint8_t * base; + size_t size; + int fd; + bool mapped; + bool pinned; - int err = fastrpc_mmap(s->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD); + void mmap(bool pinned = false) { + int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD_DELAYED); if (err != 0) { - GGML_LOG_ERROR("ggml-hex: buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", - s->domain_id, this->size, this->fd, (unsigned) err); - return false; + GGML_LOG_ERROR("ggml-hex: %s buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(), + sess->domain_id, this->size, this->fd, (unsigned) err); + throw std::runtime_error("ggml-hex: fastrpc_mmap failed (see log for details)"); } - return true; - } - - bool mmap() { - if (this->mapped) { - return true; - } - if (!mmap_to(this->sess)) { - return false; + if (pinned) { + err = htp_iface_mmap(sess->handle, this->fd, this->size, pinned); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: %s buffer pinning failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(), + sess->domain_id, this->size, this->fd, (unsigned) err); + throw std::runtime_error("ggml-hex: htp_iface_mmap failed (see log for details)"); + } } + this->mapped = true; - return true; + this->pinned = pinned; + HEX_VERBOSE("ggml-hex: %s mapped buffer: base %p size %zu fd %d pinned %u\n", + sess->c_name(), (void *) this->base, this->size, this->fd, pinned); } - void munmap() { - if (!this->mapped) { - return; - } + void unmap() { + if (!this->mapped) return; + + htp_iface_munmap(sess->handle, this->fd); + fastrpc_munmap(sess->domain_id, this->fd, (void *) this->base, this->size); + + HEX_VERBOSE("ggml-hex: %s unmapped buffer: base %p size %zu fd %d\n", sess->c_name(), + (void *) this->base, size, this->fd); - fastrpc_munmap(this->sess->domain_id, this->fd, this->base, this->size); this->mapped = false; + this->fd = -1; } - ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) { - size += 4 * 1024; // extra page for padding + void alloc(size_t size, bool pinned = false) { + if (this->base) return; - this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); + this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, size); if (!this->base) { - GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size); + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->c_name(), size); throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)"); } this->fd = rpcmem_to_fd(this->base); if (this->fd < 0) { - GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->name.c_str(), (void *) this->base); - rpcmem_free(this->base); - this->base = NULL; + GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->c_name(), (void *) this->base); throw std::runtime_error("ggml-hex: rpcmem_to_fd failed (see log for details)"); } + this->size = size; + + HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d pinned %d\n", sess->c_name(), + (void *) this->base, this->size, this->fd, (int) pinned); + + mmap(pinned); + } + + void free() { + if (!this->base) return; - HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d repack %d\n", sess->name.c_str(), - (void *) this->base, size, this->fd, (int) repack); + unmap(); + rpcmem_free(this->base); + + HEX_VERBOSE("ggml-hex: %s freed buffer: base %p size %zu fd %d\n", sess->c_name(), + (void *) this->base, size, this->fd); + + this->base = NULL; + } + + ggml_hexagon_shared_buffer(ggml_hexagon_session * sess, size_t size, bool pinned = false) { + size += 4 * 1024; // extra page for padding this->sess = sess; - this->size = size; + this->size = 0; + this->base = nullptr; + this->fd = -1; this->mapped = false; - this->repack = repack; - } - ~ggml_backend_hexagon_buffer_context() { - munmap(); - if (this->base) { - rpcmem_free(this->base); - this->base = NULL; - } + alloc(size, pinned); } - ggml_hexagon_session * sess; // primary session - uint8_t * base; - size_t size; - int fd; - bool mapped; // mmap is done - bool repack; // repacked buffer + ~ggml_hexagon_shared_buffer() { + free(); + } }; static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_buffer_t buffer) { @@ -310,30 +265,26 @@ static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_ } static void ggml_backend_hexagon_buffer_free_buffer(ggml_backend_buffer_t buffer) { - auto ctx = static_cast(buffer->context); - delete ctx; + auto sbuf = static_cast(buffer->context); + delete sbuf; } static void * ggml_backend_hexagon_buffer_get_base(ggml_backend_buffer_t buffer) { - auto ctx = static_cast(buffer->context); - return ctx->base; + auto sbuf = static_cast(buffer->context); + return sbuf->base; } static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - auto ctx = static_cast(buffer->context); - auto sess = ctx->sess; + auto sbuf = static_cast(buffer->context); + auto sess = sbuf->sess; - HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d repack %d\n", sess->name.c_str(), - tensor->name, (void *) ctx->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage, - (int) ctx->repack); + HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d\n", sess->c_name(), + tensor->name, (void *) sbuf->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage); if (tensor->view_src != NULL && tensor->view_offs == 0) { - ; // nothing to do for the view - } else { - if (!ctx->mapped) { - ctx->mmap(); - } + return GGML_STATUS_SUCCESS; // nothing to do for the view } + return GGML_STATUS_SUCCESS; } @@ -1387,11 +1338,10 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, const void * data, size_t offset, size_t size) { - auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; - auto sess = ctx->sess; + auto sbuf = (ggml_hexagon_shared_buffer *) buffer->context; + auto sess = sbuf->sess; - HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data, - offset, size); + HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->c_name(), tensor->name, data, offset, size); switch (tensor->type) { case GGML_TYPE_Q4_0: @@ -1430,11 +1380,10 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, void * data, size_t offset, size_t size) { - auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; - auto sess = ctx->sess; + auto sbuf = (ggml_hexagon_shared_buffer *) buffer->context; + auto sess = sbuf->sess; - HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data, - offset, size); + HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->c_name(), tensor->name, data, offset, size); switch (tensor->type) { case GGML_TYPE_Q4_0: @@ -1478,10 +1427,10 @@ static bool ggml_backend_hexagon_buffer_cpy_tensor(ggml_backend_buffer_t bu } static void ggml_backend_hexagon_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; - auto sess = ctx->sess; - HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->name.c_str(), (void *) ctx->base, ctx->size); - memset(ctx->base, value, ctx->size); + auto sbuf = (ggml_hexagon_shared_buffer *) buffer->context; + auto sess = sbuf->sess; + HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->c_name(), (void *) sbuf->base, sbuf->size); + memset(sbuf->base, value, sbuf->size); } static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = { @@ -1508,10 +1457,10 @@ static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer( ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast(buffer_type->context)->sess; try { - ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, false /*repack*/); - return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size); + ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); + return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { - GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what()); + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context (host): %s\n", sess->c_name(), exc.what()); return nullptr; } } @@ -1520,10 +1469,10 @@ static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffe ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast(buffer_type->context)->sess; try { - ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, true /*repack*/); - return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size); + ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); + return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { - GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what()); + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context (repack): %s\n", sess->c_name(), exc.what()); return nullptr; } } @@ -1538,7 +1487,7 @@ static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffe } static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { - return 1 * 1024 * 1024 * 1024; // 1GB per buffer + return 1UL * 1024 * 1024 * 1024; // 1GB per buffer GGML_UNUSED(buffer_type); } @@ -1570,6 +1519,373 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host, }; +// Backend session implementation + +struct ggml_hexagon_opshm { + ggml_hexagon_shared_buffer *sbuf; + + std::vector block_mask; + size_t block_size; + + uint8_t * base() const { return this->sbuf->base; } + int fd() const { return this->sbuf->fd; } + size_t n_blocks() const { return this->block_mask.size(); } + + ggml_hexagon_opshm(ggml_hexagon_session *sess, size_t max_batch, size_t max_pending) { + size_t n_bufs = HTP_OP_MAX_BUFS; + size_t n_ops = max_batch; + size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; + + block_mask.resize(max_pending, true); + + block_size = sizeof(htp_buf_desc) * n_bufs + + sizeof(htp_tensor) * n_tensors + + sizeof(htp_op_desc) * n_ops; + + sbuf = new ggml_hexagon_shared_buffer(sess, block_size * block_mask.size(), true /* pinned */); + + if (opt_verbose) { + GGML_LOG_INFO("ggml-hex: %s allocated shared buf %zu : block-size %zu max-batch %zu max-pending %zu\n", + sess->c_name(), (size_t) sbuf->size, block_size, max_batch, max_pending); + } + } + + ~ggml_hexagon_opshm() { + delete sbuf; + } + + uint8_t * allocate() { + auto it = std::find(block_mask.begin(), block_mask.end(), true); + if (it == block_mask.end()) + return nullptr; + + unsigned int i = std::distance(block_mask.begin(), it); + uint8_t* addr = sbuf->base + (i * block_size); + block_mask[i] = false; + + HEX_VERBOSE("ggml-hex: %s allocated op shm #%u %p\n", sbuf->sess->c_name(), i, (void*) addr); + return addr; + } + + void release(uint8_t * addr) { + int i = (addr - sbuf->base) / block_size; + block_mask[i] = true; + HEX_VERBOSE("ggml-hex: %s released op shm #%u %p\n", sbuf->sess->c_name(), i, (void*) addr); + } +}; + +struct ggml_hexagon_opbatch { + const char* name; + + std::vector buffers; + std::vector tensors; + std::vector ops; + + std::unordered_map b_map; // buffer fd to index + std::unordered_map t_map; // tensor ptr to index + std::unordered_multimap d_map; // tensor data to index + + unsigned int n_bufs; // num buffers in the batch + unsigned int n_tens; // num tensors ... + unsigned int n_ops; // num ops ... + size_t b_vmem; // sum of all buffer sizes + + unsigned int n_bufs_max; + unsigned int n_tens_max; + unsigned int n_ops_max; + size_t b_vmem_max; + + void reset() { + n_bufs = 0; + n_tens = 0; + n_ops = 0; + b_vmem = 0; + + b_map.clear(); + t_map.clear(); + d_map.clear(); + } + + ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t max_batch) { + name = sess->c_name(); + + n_bufs_max = HTP_OP_MAX_BUFS; + n_ops_max = max_batch; + n_tens_max = n_ops_max + n_ops_max * HTP_OP_MAX_INPUTS; + + b_vmem_max = HTP_OP_MAX_VMEM; + + buffers.resize(n_bufs_max); + tensors.resize(n_tens_max); + ops.resize(n_ops_max); + + b_map.reserve(n_bufs_max); + t_map.reserve(n_tens_max); + d_map.reserve(n_tens_max); + + reset(); + } + + bool empty() const { return n_ops == 0; } + + // add buffer and return its index + int add_buffer(ggml_hexagon_shared_buffer * sbuf) { + // Lookup by fd + auto it = b_map.find(sbuf->fd); + if (it != b_map.end()) { return it->second; } + + // Add new buffer to the batch + int bi = n_bufs++; + GGML_ASSERT(n_bufs < HTP_OP_MAX_BUFS); + + b_map.insert({sbuf->fd, bi}); + + htp_buf_desc &b = buffers[bi]; + b.base = (uint64_t) sbuf->base; + b.fd = sbuf->fd; + b.size = sbuf->size; + + b_vmem += b.size; + + HEX_VERBOSE("ggml-hex: add-buffer #%u : fd %d base %p size %zu : vmem %zu\n", bi, b.fd, (void*) sbuf->base, (size_t) b.size, b_vmem); + + return bi; + } + + bool same_shape(const htp_tensor * h, const ggml_tensor * t) const { + return (h->ne[0] == t->ne[0]) && (h->ne[1] == t->ne[1]) && (h->ne[2] == t->ne[2]) && (h->ne[3] == t->ne[3]) && + (h->nb[0] == t->nb[0]) && (h->nb[1] == t->nb[1]) && (h->nb[2] == t->nb[2]) && (h->nb[3] == t->nb[3]); + } + + // add tensor and return its index + int add_tensor(const ggml_tensor * t) { + auto sbuf = static_cast(t->buffer->context); + + // First lookup by tensor data + auto range = d_map.equal_range(t->data); + for (auto it = range.first; it != range.second; ++it) { + htp_tensor * h = &tensors[it->second]; + if (same_shape(h, t)) { return it->second; } + } + + // Lookup by tensor ptr + auto it = t_map.find(t); + if (it != t_map.end()) { return it->second; } + + // Add new tensor to the batch + int ti = n_tens++; + GGML_ASSERT(n_tens <= n_tens_max); + + t_map.insert({t, ti}); + d_map.insert({t->data, ti}); + + uint64_t t_offset = (uint8_t *) t->data - sbuf->base; + size_t t_size = ggml_nbytes(t); + + htp_tensor &h = tensors[ti]; + h.bi = add_buffer(sbuf); + h.data = t_offset; + h.size = t_size; + h.type = t->type; + h.ne[0] = t->ne[0]; h.ne[1] = t->ne[1]; h.ne[2] = t->ne[2]; h.ne[3] = t->ne[3]; + h.nb[0] = t->nb[0]; h.nb[1] = t->nb[1]; h.nb[2] = t->nb[2]; h.nb[3] = t->nb[3]; + + h.flags = 0; + if (ggml_backend_buffer_get_usage(t->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { + h.flags |= HTP_TENSOR_COMPUTE; + } + + HEX_VERBOSE("ggml-hex: add-tensor #%u %s : bi %d data %p offset %zu size %zu flags 0x%x : %zu:%zu:%zu:%zu\n", + ti, t->name, h.bi, (void*) t->data, (size_t) t_offset, t_size, h.flags, + (size_t) t->ne[0], (size_t) t->ne[1], (size_t) t->ne[2], (size_t) t->ne[3]); + + return ti; + } + + bool fit_op(const struct ggml_tensor *t) const { + if (n_ops >= n_ops_max ) return false; + + // check how much extras we will need + size_t extra_bufs = 0; + size_t extra_vmem = 0; + size_t extra_tens = 0; + + auto fit_tensor = [&](const ggml_tensor *t) { + if (!t_map.count(t)) { + extra_tens++; + + auto sbuf = static_cast(t->buffer->context); + if (!b_map.count(sbuf->fd)) { + extra_vmem += sbuf->size; + extra_bufs += 1; + } + } + }; + + for (unsigned int i=0; i < HTP_OP_MAX_INPUTS && t->src[i]; i++) { + fit_tensor(t->src[i]); + } + fit_tensor(t); + + if ((extra_bufs + n_bufs) > n_bufs_max) return false; + if ((extra_tens + n_tens) > n_tens_max) return false; + if ((extra_vmem + b_vmem) > b_vmem_max) return false; + + return true; + } + + // assumes that fit_op() was called first and returned true + void add_op(htp_op_code opcode, const struct ggml_tensor * t) { + // Add new op + htp_op_desc &o = ops[n_ops++]; + GGML_ASSERT(n_ops <= n_ops_max); + + memcpy(&o.params, &t->op_params, sizeof(t->op_params)); + o.opcode = opcode; + o.flags = 0; + + if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { + o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; + } + + ggml_hexagon_dump_op_exec(name, t, o.flags); + + for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { + o.src[i] = t->src[i] ? add_tensor(t->src[i]) : 0xffff; + } + o.dst = add_tensor(t); + } + + size_t flush(uint8_t * mem_addr, size_t mem_size) { + static_assert(sizeof(htp_buf_desc) % 8 == 0, "sizeof(htp_buf_desc) must be multiple of 8"); + static_assert(sizeof(htp_tensor) % 8 == 0, "sizeof(htp_tensor) must be multiple of 8"); + static_assert(sizeof(htp_op_desc) % 8 == 0, "sizeof(htp_op_desc) must be multiple of 8"); + + const size_t b_size = sizeof(htp_buf_desc) * n_bufs; + const size_t t_size = sizeof(htp_tensor) * n_tens; + const size_t o_size = sizeof(htp_op_desc) * n_ops; + + const size_t m_size = b_size + t_size + o_size; + GGML_ASSERT(m_size <= mem_size); + + uint8_t * b_ptr = (uint8_t *) mem_addr; + uint8_t * t_ptr = (uint8_t *) b_ptr + b_size; + uint8_t * o_ptr = (uint8_t *) t_ptr + t_size; + + memcpy(b_ptr, (void *) buffers.data(), b_size); + memcpy(t_ptr, (void *) tensors.data(), t_size); + memcpy(o_ptr, (void *) ops.data(), o_size); + + HEX_VERBOSE("ggml-hex: %s flush-opbatch : n-bufs %u n-tensors %u n-ops %u vmem %zu : b-size %zu t-size %zu o-size %zu\n", + name, n_bufs, n_tens, n_ops, b_vmem, b_size, t_size, o_size); + + if (opt_verbose > 1) { + htp_buf_desc *b = (htp_buf_desc*) b_ptr; + for (unsigned int i=0; i < n_bufs; i++) { + GGML_LOG_DEBUG("ggml-hex: %s htp-buf #%u : fd %d base %p size %zu\n", name, i, + b[i].fd, (void *) b[i].base, (size_t) b[i].size); + } + htp_tensor *t = (htp_tensor*) t_ptr; + for (unsigned int i=0; i < n_tens; i++) { + GGML_LOG_DEBUG("ggml-hex: %s htp-tensor #%u : bi %u offset %u size %u : %zu:%zu:%zu:%zu\n", + name, i, t[i].bi, t[i].data, t[i].size, + (size_t) t[i].ne[0], (size_t) t[i].ne[1], (size_t) t[i].ne[2], (size_t) t[i].ne[3]); + } + } + + reset(); + + return m_size; + } +}; + +// Flush HTP response queue i.e wait for all outstanding requests to complete +void ggml_hexagon_session::flush_pending(bool all) { + while (this->op_pending) { + struct htp_opbatch_rsp rsp; + uint32_t rsp_size; + uint32_t flags; + + struct dspqueue_buffer dbuf; + uint32_t n_dbufs; + + // Read response packet from queue + int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, DSPQUEUE_TIMEOUT); + if (err == AEE_EEXPIRED) { + continue; + } + + if (err != 0) { + GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err); + } + + // Basic sanity checks + if (rsp_size != sizeof(rsp) || n_dbufs != 1) { + GGML_ABORT("ggml-hex: %s dspcall : bad response : size %u dspbufs %u\n", this->c_name(), rsp_size, n_dbufs); + } + + op_shm->release((uint8_t*) dbuf.ptr); + + if (rsp.status != HTP_STATUS_OK) { + GGML_LOG_ERROR("ggml-hex: %s dspcall : dsp-rsp: %s\n", this->c_name(), status_to_str(rsp.status)); + // TODO: handle errors + } + + // FIXME: profile will be per opreq + // this->prof_usecs = rsp.prof_usecs; + // this->prof_cycles = rsp.prof_cycles; + // this->prof_pkts = rsp.prof_pkts; + + this->op_pending--; // atomic dec + + if (!all) break; + } +} + +void ggml_hexagon_session::flush_batch() { + if (op_batch->empty()) { return; } + + htp_opbatch_req req; + req.n_bufs = op_batch->n_bufs; + req.n_tensors = op_batch->n_tens; + req.n_ops = op_batch->n_ops; + + dspqueue_buffer dbuf; + dbuf.fd = op_shm->fd(); + dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + dbuf.ptr = op_shm->allocate(); + if (!dbuf.ptr) { + flush_pending(false); + dbuf.ptr = op_shm->allocate(); + } + + dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) op_shm->base(); + dbuf.size = op_batch->flush((uint8_t*) dbuf.ptr, op_shm->block_size); + + // Bump pending flag (cleared in the session::flush once we get the response) + this->op_pending++; // atomic inc + + HEX_VERBOSE("ggml-hex: %s: queue-opbatch : %p size %u\n", this->c_name(), dbuf.ptr, dbuf.size); + + int err = dspqueue_write(this->queue, 0, 1, &dbuf, sizeof(req), (const uint8_t*) &req, DSPQUEUE_TIMEOUT); + if (err != 0) { + GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->c_name(), (unsigned) err); + } +} + +void ggml_hexagon_session::enqueue_op(htp_op_code opcode, const ggml_tensor *op) { + if (!op_batch->fit_op(op)) { + flush_batch(); + } + op_batch->add_op(opcode, op); +} + +// Flush HTP response queue i.e wait for all outstanding requests to complete +void ggml_hexagon_session::flush(bool all) { + flush_batch(); + flush_pending(all); +} + void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_session = false; this->valid_handle = false; @@ -1582,9 +1898,6 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->name = std::string("HTP") + std::to_string(dev_id); this->op_pending = 0; - this->prof_usecs = 0; - this->prof_cycles = 0; - this->prof_pkts = 0; GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str()); @@ -1676,11 +1989,14 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } } + const size_t req_q_size = (sizeof(htp_opbatch_req) * opt_opqueue * 2) + 1024; + const size_t rsp_q_size = (sizeof(htp_opbatch_rsp) * opt_opqueue * 2) + 1024; + // Now let's setup the DSP queue err = dspqueue_create(this->domain_id, 0, // Flags - 128 * 1024, // Request queue size (in bytes) - 64 * 1024, // Response queue size (in bytes) + req_q_size, // Request queue size (in bytes) + rsp_q_size, // Response queue size (in bytes) nullptr, // Read packet callback (we handle reads explicitly) nullptr, // Error callback (we handle errors during reads) (void *) this, // Callback context @@ -1715,6 +2031,10 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); } this->valid_iface = true; + + // Allocate buffers and state for op batching + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch); + this->op_shm = new ggml_hexagon_opshm(this, opt_opbatch, opt_opqueue); } void ggml_hexagon_session::release() noexcept(true) { @@ -1722,6 +2042,9 @@ void ggml_hexagon_session::release() noexcept(true) { int err; + delete this->op_batch; + delete this->op_shm; + // Stop the DSP-side service and close the queue if (this->valid_iface) { err = htp_iface_stop(this->handle); @@ -1753,6 +2076,9 @@ ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) n buffer_type.device = dev; repack_buffer_type.device = dev; + op_batch = nullptr; + op_shm = nullptr; + try { allocate(dev_id); @@ -1815,9 +2141,13 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess return false; } - return opt_experimental; -} + if (dst->ne[2] != 1 || dst->ne[3] != 1) { + // FA during prompt still needs work + return false; + } + return true; +} static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; @@ -2082,6 +2412,23 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s } } + // Reject non-HVX-aligned sizes when ne[0] > HVX_F32_LANES + // The HVX softmax implementation has issues with tail handling for larger non-aligned sizes + // Small sizes (ne[0] <= 32) work correctly with tail-only processing + const int64_t ne0 = src0->ne[0]; + if (ne0 > 32 && (ne0 & (32 - 1)) != 0) { + return false; + } + + // HVX vector size constraints for softmax + #define SOFTMAX_MAX_ROW_SIZE 131072 // 128K elements max for numerical precision + + // Reject very large row sizes to avoid numerical precision issues + // Softmax accumulation over many elements can lead to precision loss + if (ne0 > SOFTMAX_MAX_ROW_SIZE) { + return false; + } + return true; } @@ -2249,571 +2596,85 @@ static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * se return true; } -enum dspqbuf_type { - DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0, - DSPQBUF_TYPE_CPU_WRITE_DSP_READ, - DSPQBUF_TYPE_CONSTANT, -}; - -static void dspqbuf_dump(dspqueue_buffer * d, const struct ggml_tensor * t, dspqbuf_type type) { - if (opt_verbose < 2) return; - - auto buf = static_cast(t->buffer->context); - auto sess = buf->sess; - - GGML_LOG_DEBUG("ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\n", sess->name.c_str(), - t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset, - (unsigned int) d->size); -} - -// Init hexagon tensor from GGML tensor and Hexagon buffer -static void htp_req_tensor_init(htp_tensor * h, const ggml_tensor * t) { - h->data = 0; // updated by the receiver - h->type = t->type; - h->ne[0] = t->ne[0]; - h->ne[1] = t->ne[1]; - h->ne[2] = t->ne[2]; - h->ne[3] = t->ne[3]; - h->nb[0] = t->nb[0]; - h->nb[1] = t->nb[1]; - h->nb[2] = t->nb[2]; - h->nb[3] = t->nb[3]; -} - -static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_tensor * t, dspqbuf_type type) { - if (!t) { - return 0; - } - - auto buf = static_cast(t->buffer->context); - - memset(d, 0, sizeof(*d)); - d->fd = buf->fd; - d->ptr = t->data; - d->offset = (uint8_t *) t->data - buf->base; - d->size = ggml_nbytes(t); - - if (!d->size) { - // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty - d->size = 64; - } - - switch (type) { - case DSPQBUF_TYPE_DSP_WRITE_CPU_READ: - // Flush CPU - d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER; - break; - case DSPQBUF_TYPE_CPU_WRITE_DSP_READ: - // Flush CPU, Invalidate DSP - d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; - break; - default: - // Constant buffer, no cache maintenance - d->flags = 0; - break; - } - - htp_req_tensor_init(h, t); - - dspqbuf_dump(d, t, type); - - return 1; -} - -typedef size_t (*htp_req_init_func_t)(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * op); - -template -static inline void ggml_hexagon_dispatch_op(ggml_hexagon_session *sess, const struct ggml_tensor * op, uint32_t flags) { - uint64_t t = ggml_time_us(); - - // Construct HTP request - htp_general_req req; - memset(&req, 0, sizeof(req)); - - req.flags = flags; - if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { - req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; - } - if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { - req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; - } - - ggml_hexagon_dump_op_exec(sess->name, op, req.flags); - - if ((opt_opmask & HTP_OPMASK_QUEUE)) { - dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; - size_t n_bufs = _init_req_func(&req, bufs, op); - sess->enqueue(req, bufs, n_bufs, opt_opsync); - } - - t = ggml_time_us() - t; - - ggml_hexagon_dump_op_prof(sess->name, op, sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, t); -} - -template -static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - switch (t->op) { - case GGML_OP_MUL_MAT: - req->op = HTP_OP_MUL_MAT; - break; - case GGML_OP_MUL: - req->op = HTP_OP_MUL; - break; - case GGML_OP_ADD: - req->op = HTP_OP_ADD; - break; - case GGML_OP_SUB: - req->op = HTP_OP_SUB; - break; - case GGML_OP_DIV: - req->op = HTP_OP_DIV; - break; - default: - GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op); - break; - } - - // src0: Weights (mulmat) or First Operand (binary op). - // If constant (e.g. weights), no cache management is needed. - // src1: Input Activations (mulmat) or Second Operand (binary op). - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_CPY; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_cont_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - // CONT is just a contiguous copy — reuse CPY op - req->op = HTP_OP_CPY; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_REPEAT; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_cumsum_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_CUMSUM; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_GET_ROWS; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_ARGSORT; - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -template -static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - switch (t->op) { - case GGML_OP_MUL_MAT_ID: - req->op = HTP_OP_MUL_MAT_ID; - break; - case GGML_OP_ADD_ID: - req->op = HTP_OP_ADD_ID; - break; - default: - GGML_ABORT("ggml-hex: unsupported op: %d\n", t->op); - } - - // src0: Weights (mulmat) or Input Activations (other op). - // If constant, no cache management is needed. - // src1: Input Activations (mulmat) or Second Operand (binary op). - // src2: Expert IDs (mulmat) or Activated Experts (other op). - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; +static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { + auto sess = static_cast(backend->context); + return sess->c_name(); } -static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_SET_ROWS; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; +static void ggml_backend_hexagon_free(ggml_backend_t backend) { + // we just need to delete the backend here + // the sessions are allocated & freed as part of the registry + delete backend; } -static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - - bool supported = false; - +static htp_op_code op_remap_to_htp(const ggml_tensor * t) { switch (t->op) { - case GGML_OP_RMS_NORM: - req->op = HTP_OP_RMS_NORM; - supported = true; - break; - - case GGML_OP_SCALE: - req->op = HTP_OP_SCALE; - supported = true; - break; - - case GGML_OP_SQR: - req->op = HTP_OP_SQR; - supported = true; - break; - - case GGML_OP_SQRT: - req->op = HTP_OP_SQRT; - supported = true; - break; + case GGML_OP_FLASH_ATTN_EXT: return HTP_OP_FLASH_ATTN_EXT; + case GGML_OP_MUL_MAT: return HTP_OP_MUL_MAT; + case GGML_OP_MUL_MAT_ID: return HTP_OP_MUL_MAT_ID; + case GGML_OP_MUL: return HTP_OP_MUL; + case GGML_OP_ADD: return HTP_OP_ADD; + case GGML_OP_ADD_ID: return HTP_OP_ADD_ID; + case GGML_OP_SUB: return HTP_OP_SUB; + case GGML_OP_DIV: return HTP_OP_DIV; + case GGML_OP_CPY: return HTP_OP_CPY; + case GGML_OP_CONT: return HTP_OP_CPY; + case GGML_OP_GET_ROWS: return HTP_OP_GET_ROWS; + case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; + case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; + case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; + case GGML_OP_SCALE: return HTP_OP_SCALE; + case GGML_OP_SQR: return HTP_OP_SQR; + case GGML_OP_SQRT: return HTP_OP_SQRT; + case GGML_OP_SOFT_MAX: return HTP_OP_SOFTMAX; + case GGML_OP_SSM_CONV: return HTP_OP_SSM_CONV; + case GGML_OP_ROPE: return HTP_OP_ROPE; + case GGML_OP_REPEAT: return HTP_OP_REPEAT; + case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { - case GGML_UNARY_OP_SILU: - req->op = HTP_OP_UNARY_SILU; - supported = true; - break; - case GGML_UNARY_OP_GELU: - req->op = HTP_OP_UNARY_GELU; - supported = true; - break; - case GGML_UNARY_OP_SIGMOID: - req->op = HTP_OP_UNARY_SIGMOID; - supported = true; - break; - case GGML_UNARY_OP_NEG: - req->op = HTP_OP_UNARY_NEG; - supported = true; - break; - case GGML_UNARY_OP_EXP: - req->op = HTP_OP_UNARY_EXP; - supported = true; - break; - case GGML_UNARY_OP_SOFTPLUS: - req->op = HTP_OP_UNARY_SOFTPLUS; - supported = true; - break; + case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; + case GGML_UNARY_OP_GELU: return HTP_OP_UNARY_GELU; + case GGML_UNARY_OP_SIGMOID: return HTP_OP_UNARY_SIGMOID; + case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG; + case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP; + case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS; default: break; } break; case GGML_OP_GLU: - if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU) { - req->op = HTP_OP_GLU_SWIGLU; - supported = true; - } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) { - req->op = HTP_OP_GLU_SWIGLU_OAI; - supported = true; - } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) { - req->op = HTP_OP_GLU_GEGLU; - supported = true; + switch (ggml_get_glu_op(t)) { + case GGML_GLU_OP_SWIGLU: return HTP_OP_GLU_SWIGLU; + case GGML_GLU_OP_SWIGLU_OAI: return HTP_OP_GLU_SWIGLU_OAI; + case GGML_GLU_OP_GEGLU: return HTP_OP_GLU_GEGLU; + default: break; } break; - case GGML_OP_SOFT_MAX: - req->op = HTP_OP_SOFTMAX; - supported = true; - break; - default: - break; - } - - if (!supported) { - GGML_ABORT("ggml-hex: unary : unsupported op: %d\n", t->op); + GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(t)); } - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - req->op = HTP_OP_SUM_ROWS; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - req->op = HTP_OP_ROPE; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; + return HTP_OP_INVALID; } -static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - req->op = HTP_OP_FLASH_ATTN_EXT; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_SSM_CONV; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CONSTANT); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { - auto sess = static_cast(backend->context); - return sess->name.c_str(); -} - -static void ggml_backend_hexagon_free(ggml_backend_t backend) { - // we just need to delete the backend here - // the sessions are allocated & freed as part of the registry - delete backend; -} - -// Map weight type to its activation quantization family. -// Types in the same family produce identical Q8 formats in VTCM and can -// safely share quantized activation data via SKIP_QUANTIZE. -// When adding a new quantized type, assign it the correct family here. -static inline int act_quant_family(enum ggml_type wtype) { - switch (wtype) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q8_0: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_MXFP4: - return 1; // Q8x4x2 - default: - return 0; // unknown / not quantized - } -} - -static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) { - return (op0 && op0->src[1] == op1->src[1] && - act_quant_family(op0->src[0]->type) == act_quant_family(op1->src[0]->type) && - act_quant_family(op0->src[0]->type) != 0); -} - -static inline bool is_compute_op(ggml_tensor *node) +static inline bool op_is_compute(ggml_tensor *node) { return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE); } -// scan the graph and figure out last compute op index -static inline int last_compute_op(ggml_cgraph * graph) { - int last = 0; - for (int i = 0; i < graph->n_nodes; ++i) { - if (is_compute_op(graph->nodes[i])) { - last = i; - } - } - - return last; -} - static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) { auto sess = static_cast(backend->context); - HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->name.c_str(), graph->n_nodes); - - const int last = last_compute_op(graph); - - const struct ggml_tensor * prev_op = nullptr; // prev executed op + HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->c_name(), graph->n_nodes); for (int i = 0; i < graph->n_nodes; ++i) { - ggml_tensor * node = graph->nodes[i]; - - if (!is_compute_op(node)) { - continue; - } - - uint32_t flags = 0; - - // skip quantizer if src1 is reused - if (op_reuse_src1(node, prev_op)) { - flags |= HTP_OPFLAGS_SKIP_QUANTIZE; - } - - prev_op = node; - - // ask for early notification for the last Op - if (i == last) { - flags |= HTP_OPFLAGS_EARLY_WAKEUP; - } - - switch (node->op) { - case GGML_OP_MUL_MAT: - if (ggml_is_quantized(node->src[0]->type)) { - ggml_hexagon_dispatch_op>(sess, node, flags); - } else { - ggml_hexagon_dispatch_op>(sess, node, flags); - } - break; - case GGML_OP_MUL_MAT_ID: - if (ggml_is_quantized(node->src[0]->type)) { - ggml_hexagon_dispatch_op>(sess, node, flags); - } else { - ggml_hexagon_dispatch_op>(sess, node, flags); - } - break; - case GGML_OP_MUL: - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_DIV: - ggml_hexagon_dispatch_op>(sess, node, flags); - break; - case GGML_OP_ADD_ID: - ggml_hexagon_dispatch_op>(sess, node, flags); - break; - case GGML_OP_RMS_NORM: - case GGML_OP_SCALE: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - case GGML_OP_SQR: - case GGML_OP_SQRT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - case GGML_OP_SUM_ROWS: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(node)) { - case GGML_UNARY_OP_NEG: - case GGML_UNARY_OP_EXP: - case GGML_UNARY_OP_SIGMOID: - case GGML_UNARY_OP_SOFTPLUS: - case GGML_UNARY_OP_SILU: - case GGML_UNARY_OP_GELU: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - default: - break; - } - break; - case GGML_OP_GLU: - switch (ggml_get_glu_op(node)) { - case GGML_GLU_OP_SWIGLU: - case GGML_GLU_OP_SWIGLU_OAI: - case GGML_GLU_OP_GEGLU: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - default: - break; - } - break; - case GGML_OP_SOFT_MAX: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_ROPE: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_FLASH_ATTN_EXT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_SET_ROWS: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_GET_ROWS: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_CPY: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_CONT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_REPEAT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_ARGSORT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_SSM_CONV: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_CUMSUM: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - default: - GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); + ggml_tensor * n = graph->nodes[i]; + if (op_is_compute(n)) { + sess->enqueue_op(op_remap_to_htp(n), n); } } @@ -2826,7 +2687,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) { auto sess = static_cast(backend->context); - HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str()); + HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->c_name()); // Wait until all pending ops complete sess->flush(); @@ -3045,7 +2906,7 @@ static ggml_backend_t ggml_backend_hexagon_device_init(ggml_backend_dev_t dev, c static const char * ggml_backend_hexagon_device_get_name(ggml_backend_dev_t dev) { auto sess = static_cast(dev->context); - return sess->name.c_str(); + return sess->c_name(); GGML_UNUSED(dev); } @@ -3056,8 +2917,7 @@ static const char * ggml_backend_hexagon_device_get_description(ggml_backend_dev } static void ggml_backend_hexagon_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - // ~2GB per session for now - *free = 2ULL * 1024 * 1024 * 1024; + *free = 0; *total = *free; GGML_UNUSED(dev); @@ -3172,6 +3032,11 @@ static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * se static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast(dev->context); + // reject ops that match the filter + if (opt_opfilter && std::regex_match(ggml_op_desc(op), *opt_opfilter)) { + return false; + } + // all srcs & dsts must be mapped to the same session if (!ggml_hexagon_supported_buffers(sess, op)) { ggml_hexagon_dump_op_supp(sess->name, op, false); @@ -3188,6 +3053,13 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = true; break; + case GGML_OP_MUL: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_DIV: + supp = ggml_hexagon_supported_binary(sess, op); + break; + case GGML_OP_MUL_MAT: supp = ggml_hexagon_supported_mul_mat(sess, op); break; @@ -3196,13 +3068,6 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_mul_mat_id(sess, op); break; - case GGML_OP_MUL: - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_DIV: - supp = ggml_hexagon_supported_binary(sess, op); - break; - case GGML_OP_ADD_ID: supp = ggml_hexagon_supported_add_id(sess, op); break; @@ -3241,6 +3106,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; } break; + case GGML_OP_GLU: switch (ggml_get_glu_op(op)) { case GGML_GLU_OP_SWIGLU: @@ -3252,6 +3118,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; } break; + case GGML_OP_ROPE: supp = ggml_hexagon_supported_rope(sess, op); break; @@ -3438,11 +3305,13 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL, "please update hexagon_type to match ggml_type"); - const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL"); const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); const char * str_opmask = getenv("GGML_HEXAGON_OPMASK"); const char * str_opsync = getenv("GGML_HEXAGON_OPSYNC"); + const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); + const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); + const char * str_opfilter= getenv("GGML_HEXAGON_OPFILTER"); const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_etm = getenv("GGML_HEXAGON_ETM"); const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); @@ -3450,16 +3319,21 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); const char * str_arch = getenv("GGML_HEXAGON_ARCH"); - opt_experimental = str_experimental ? atoi(str_experimental) : 0; + auto RE_ICASE = std::regex_constants::icase; + + opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; opt_verbose = str_verbose ? atoi(str_verbose) : 0; opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; - opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; - opt_opsync = str_opsync ? atoi(str_opsync) : 0; + opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; + opt_opsync = str_opsync ? atoi(str_opsync) : opt_opsync; + opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; + opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; opt_profile = str_profile ? atoi(str_profile) : 0; opt_etm = str_etm ? atoi(str_etm) : 0; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { opt_ndev = GGML_HEXAGON_MAX_SESSIONS; @@ -3472,12 +3346,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { opt_arch = strtoul(str_arch, NULL, 0); } - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1; - reg->context = new ggml_hexagon_registry(reg); - - HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req), - sizeof(struct htp_general_rsp)); } static const struct ggml_backend_reg_i ggml_backend_hexagon_reg_i = { diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index d8b924981e0..6416d2dfbc3 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -14,59 +14,42 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" -#define htp_act_preamble3 \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne10 = src1->ne[0]; \ - const uint32_t ne11 = src1->ne[1]; \ - const uint32_t ne12 = src1->ne[2]; \ - const uint32_t ne13 = src1->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb10 = src1->nb[0]; \ - const uint32_t nb11 = src1->nb[1]; \ - const uint32_t nb12 = src1->nb[2]; \ - const uint32_t nb13 = src1->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ - const uint32_t nb3 = dst->nb[3]; - -#define htp_act_preamble2 \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_act_preamble \ + const struct htp_tensor * src0 = actx->octx->src[0]; \ + const struct htp_tensor * src1 = actx->octx->src[1]; \ + const struct htp_tensor * dst = actx->octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne10 = src1 ? src1->ne[0] : 0; \ + const uint32_t ne11 = src1 ? src1->ne[1] : 0; \ + const uint32_t ne12 = src1 ? src1->ne[2] : 0; \ + const uint32_t ne13 = src1 ? src1->ne[3] : 0; \ + \ + const uint32_t nb10 = src1 ? src1->nb[0] : 0; \ + const uint32_t nb11 = src1 ? src1->nb[1] : 0; \ + const uint32_t nb12 = src1 ? src1->nb[2] : 0; \ + const uint32_t nb13 = src1 ? src1->nb[3] : 0; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; struct htp_act_context { @@ -97,10 +80,7 @@ struct htp_act_context { static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * src1 = &actx->octx->src1; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble3; + htp_act_preamble; size_t src0_row_size = actx->src0_row_size; size_t src1_row_size = actx->src1_row_size; @@ -207,10 +187,7 @@ static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * src1 = &actx->octx->src1; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble3; + htp_act_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -332,9 +309,7 @@ static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, vo static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble2; + htp_act_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -433,9 +408,7 @@ static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble2; + htp_act_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -533,10 +506,7 @@ static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * src1 = &actx->octx->src1; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble3; + htp_act_preamble; size_t src0_row_size = actx->src0_row_size; size_t src1_row_size = actx->src1_row_size; @@ -652,9 +622,9 @@ static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * } static int execute_op_activations_f32(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) { FARF(ERROR, "Non-contiguous tensors are not supported at this time \n"); @@ -697,25 +667,20 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); size_t src0_row_size = src0->nb[1]; - size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used + size_t src1_row_size = src1 ? src1->nb[1] : src0->nb[1]; size_t dst_row_size = dst->nb[1]; - const bool src1_valid = src1->ne[0]; - if (!src1_valid) { - src1_row_size = src0_row_size; - } - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size - size_t spad_size_per_row = (src0_row_size_aligned + src1_row_size_aligned) + dst_row_size_aligned; size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads* spad_size_per_row); // Make sure the reserved vtcm size is sufficient - if(vtcm_row_per_thread ==0){ + if (vtcm_row_per_thread == 0) { FARF(ERROR, "act-%s : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", op_type, octx->ctx->vtcm_size, spad_size_per_row * n_threads); return HTP_STATUS_VTCM_TOO_SMALL; @@ -733,7 +698,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - if (src1->ne[0]) { + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->dst_spad.src = NULL; + + if (src1) { FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, @@ -773,9 +742,9 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { // Pointers and GLU logic const uint8_t * data_src0 = (const uint8_t *) src0->data; - const uint8_t * data_src1 = (const uint8_t *) src1->data; + const uint8_t * data_src1 = src1 ? (const uint8_t *) src1->data : NULL; - if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) { + if (!src1 && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) { const int32_t swapped = octx->op_params[1]; data_src1 = data_src0; actx.src1_row_size = actx.src0_row_size; @@ -799,7 +768,7 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { int op_activations(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_activations_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index 3ec26a4c1ac..bdd0623615d 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -12,7 +12,7 @@ #include "hex-dma.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #ifndef MIN @@ -175,8 +175,8 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = actx->octx; // Unpack context - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; // Scratchpad memory uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i; @@ -249,16 +249,16 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { int op_argsort(struct htp_ops_context * octx) { // Check supported types - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; + const uint32_t total_rows = octx->src[0]->ne[1] * octx->src[0]->ne[2] * octx->src[0]->ne[3]; const uint32_t n_threads = MIN(total_rows, octx->n_threads); // Allocate scratchpad // We need 1 row of float + 1 row of int32 per thread. - uint32_t ne00 = octx->src0.ne[0]; + uint32_t ne00 = octx->src[0]->ne[0]; size_t values_size = hex_round_up(ne00 * sizeof(float), 128); size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128); size_t spad_per_thread = values_size + indices_size; @@ -278,9 +278,9 @@ int op_argsort(struct htp_ops_context * octx) { octx->src0_spad.size_per_thread = spad_per_thread; FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", - octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3], - octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3], - octx->src0.data, octx->dst.data); + octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3], + octx->dst->ne[0], octx->dst->ne[1], octx->dst->ne[2], octx->dst->ne[3], + octx->src[0]->data, octx->dst->data); struct htp_argsort_context actx; actx.octx = octx; diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index 1b0f97493bc..52013ad0fec 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -14,7 +14,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #ifndef MIN @@ -43,10 +43,10 @@ struct htp_binary_context { bool split_at_ne02; }; -#define htp_binary_preamble \ - const struct htp_tensor * src0 = &octx->src0; \ - const struct htp_tensor * src1 = &octx->src1; \ - struct htp_tensor * dst = &octx->dst; \ +#define htp_binary_preamble \ + const struct htp_tensor * src0 = octx->src[0]; \ + const struct htp_tensor * src1 = octx->src[1]; \ + const struct htp_tensor * dst = octx->dst; \ \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ @@ -181,7 +181,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; @@ -274,7 +274,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; @@ -374,7 +374,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; @@ -455,7 +455,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; @@ -540,7 +540,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); const uint32_t row_size_bytes = ne00 * elem_size_bytes;; const uint32_t total_rows = ne01 * ne02 * ne03; @@ -629,10 +629,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { struct htp_binary_context * bctx = (struct htp_binary_context *) data; struct htp_ops_context * octx = bctx->octx; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * src2 = &octx->src2; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * src2 = octx->src[2]; + const struct htp_tensor * dst = octx->dst; const uint32_t ne00 = src0->ne[0]; const uint32_t ne01 = src0->ne[1]; @@ -723,15 +723,15 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { } static int execute_op_binary(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); // Use packed row sizes for VTCM allocation - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); const size_t src0_row_size = src0->ne[0] * elem_size; const size_t src1_row_size = src1->ne[0] * elem_size; @@ -799,9 +799,9 @@ static int execute_op_binary(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { return HTP_STATUS_OK; @@ -857,12 +857,12 @@ static int execute_op_binary(struct htp_ops_context * octx) { int op_binary(struct htp_ops_context * octx) { // Does not support permutations of src1 - const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src1 = octx->src[1]; if (src1->nb[1] < src1->nb[0]) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) { return execute_op_binary(octx); } diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index a40d866b9c3..e5b9d350fd7 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -11,7 +11,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #include "hvx-utils.h" @@ -32,10 +32,10 @@ struct htp_copy_context { void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith); }; -#define cpy_preamble \ - struct htp_tensor *src0 = &octx->src0; \ - struct htp_tensor *dst = &octx->dst; \ - \ +#define cpy_preamble \ + const struct htp_tensor *src0 = octx->src[0]; \ + const struct htp_tensor *dst = octx->dst; \ + \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ const uint32_t ne02 = src0->ne[2]; \ diff --git a/ggml/src/ggml-hexagon/htp/cumsum-ops.c b/ggml/src/ggml-hexagon/htp/cumsum-ops.c index ce51555a7fd..2ced1971236 100644 --- a/ggml/src/ggml-hexagon/htp/cumsum-ops.c +++ b/ggml/src/ggml-hexagon/htp/cumsum-ops.c @@ -13,9 +13,9 @@ #include "hvx-utils.h" #include "hex-dma.h" -#define htp_cumsum_tensors_preamble \ - struct htp_tensor * restrict src0 = &octx->src0; \ - struct htp_tensor * restrict dst = &octx->dst; \ +#define htp_cumsum_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict dst = octx->dst; \ \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ @@ -206,8 +206,8 @@ static void cumsum_thread_f32(unsigned int nth, unsigned int ith, void * data) { } int op_cumsum_f32(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { return HTP_STATUS_OK; @@ -226,10 +226,12 @@ int op_cumsum_f32(struct htp_ops_context * octx) { octx->src0_spad.size_per_thread = src_row_size_aligned * 2; octx->dst_spad.size_per_thread = dst_row_size_aligned * 2; - octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; - octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; struct htp_cumsum_context cctx = { .octx = octx, @@ -251,8 +253,9 @@ int op_cumsum_f32(struct htp_ops_context * octx) { } int op_cumsum(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; switch (dst->type) { case HTP_TYPE_F32: diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 0c9bc785620..d296a322589 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -15,7 +15,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" // Must be multiple of 32 @@ -278,12 +278,12 @@ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_fa_context * factx = (struct htp_fa_context *) data; const struct htp_ops_context * octx = factx->octx; - const struct htp_tensor * q = &octx->src0; - const struct htp_tensor * k = &octx->src1; - const struct htp_tensor * v = &octx->src2; - const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; - const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = octx->src[3]; + const struct htp_tensor * sinks = octx->src[4]; + const struct htp_tensor * dst = octx->dst; const uint32_t neq0 = q->ne[0]; const uint32_t neq1 = q->ne[1]; @@ -610,11 +610,11 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * } int op_flash_attn_ext(struct htp_ops_context * octx) { - const struct htp_tensor * q = &octx->src0; - const struct htp_tensor * k = &octx->src1; - const struct htp_tensor * v = &octx->src2; - const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = octx->src[3]; + const struct htp_tensor * dst = octx->dst; // Check support if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) { @@ -701,13 +701,11 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; - octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; - - // FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread); + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->src2_spad.src = NULL; + octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; octx->src3_spad.src = NULL; + octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; octx->dst_spad.src = NULL; if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads); diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index 047d2850aaa..5a1dc933860 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -11,7 +11,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #include "hvx-utils.h" @@ -23,27 +23,33 @@ struct get_rows_context { }; #define get_rows_preamble \ - const uint32_t ne00 = octx->src0.ne[0]; \ - const uint32_t ne01 = octx->src0.ne[1]; \ - const uint32_t ne02 = octx->src0.ne[2]; \ - const uint32_t ne03 = octx->src0.ne[3]; \ - \ - const uint32_t ne10 = octx->src1.ne[0]; \ - const uint32_t ne11 = octx->src1.ne[1]; \ - const uint32_t ne12 = octx->src1.ne[2]; \ - \ - const uint32_t nb01 = octx->src0.nb[1]; \ - const uint32_t nb02 = octx->src0.nb[2]; \ - const uint32_t nb03 = octx->src0.nb[3]; \ - \ - const uint32_t nb10 = octx->src1.nb[0]; \ - const uint32_t nb11 = octx->src1.nb[1]; \ - const uint32_t nb12 = octx->src1.nb[2]; \ - \ - const uint32_t nb1 = octx->dst.nb[1]; \ - const uint32_t nb2 = octx->dst.nb[2]; \ - const uint32_t nb3 = octx->dst.nb[3]; \ - \ + const uint32_t ne00 = octx->src[0]->ne[0]; \ + const uint32_t ne01 = octx->src[0]->ne[1]; \ + const uint32_t ne02 = octx->src[0]->ne[2]; \ + const uint32_t ne03 = octx->src[0]->ne[3]; \ + \ + const uint32_t ne10 = octx->src[1]->ne[0]; \ + const uint32_t ne11 = octx->src[1]->ne[1]; \ + const uint32_t ne12 = octx->src[1]->ne[2]; \ + const uint32_t ne13 = octx->src[1]->ne[3]; \ + \ + const uint32_t ne0 = octx->dst->ne[0]; \ + const uint32_t ne1 = octx->dst->ne[1]; \ + const uint32_t ne2 = octx->dst->ne[2]; \ + const uint32_t ne3 = octx->dst->ne[3]; \ + \ + const uint32_t nb01 = octx->src[0]->nb[1]; \ + const uint32_t nb02 = octx->src[0]->nb[2]; \ + const uint32_t nb03 = octx->src[0]->nb[3]; \ + \ + const uint32_t nb10 = octx->src[1]->nb[0]; \ + const uint32_t nb11 = octx->src[1]->nb[1]; \ + const uint32_t nb12 = octx->src[1]->nb[2]; \ + \ + const uint32_t nb1 = octx->dst->nb[1]; \ + const uint32_t nb2 = octx->dst->nb[2]; \ + const uint32_t nb3 = octx->dst->nb[3]; \ + \ const uint32_t nr = ne10 * ne11 * ne12; static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { @@ -51,12 +57,14 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da struct htp_ops_context * octx = grctx->octx; get_rows_preamble; + uint64_t qt = HAP_perf_get_qtimer_count(); + // parallelize by src1 elements (which correspond to dst rows) const uint32_t dr = grctx->src1_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; - const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); for (uint32_t i = ir0; i < ir1; ++i) { const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11); @@ -64,7 +72,7 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10); const uint32_t i10 = rem - i11 * ne10; - const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; @@ -73,10 +81,14 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da continue; } - const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03; - const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3; + const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03; + const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3; hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "get-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } int op_get_rows(struct htp_ops_context * octx) { @@ -84,15 +96,15 @@ int op_get_rows(struct htp_ops_context * octx) { const uint32_t n_threads = MIN(nr, octx->n_threads); - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - if (octx->dst.type != HTP_TYPE_F32) { + if (octx->dst->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) { return HTP_STATUS_NO_SUPPORT; } @@ -102,8 +114,8 @@ int op_get_rows(struct htp_ops_context * octx) { struct get_rows_context grctx; grctx.octx = octx; - grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); - grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); + grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src[1]->ne[0]); + grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src[1]->ne[0] * octx->src[1]->ne[1]); grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads; diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index 8ed1456bc54..fe0b661e309 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -3,8 +3,10 @@ #include #include +#include #include "hexagon_types.h" +#include "hexagon_protos.h" #include "hex-fastdiv.h" #include "hex-dump.h" @@ -68,4 +70,23 @@ static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, Q6_l2fetch_AP((void *) p, control); } +#define HEX_L2_LINE_SIZE 64 +#define HEX_L2_FLUSH_SIZE (128 * 1024) + +static inline void hex_l2flush(void * addr, size_t size) +{ + if (size > HEX_L2_FLUSH_SIZE) { + qurt_mem_cache_clean((qurt_addr_t) 0, 0, QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE); + } else { + const uint32_t s = (uint32_t) addr; + const uint32_t e = s + size; + for (uint32_t i = s; i < e; i += HEX_L2_LINE_SIZE * 4) { + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 0); + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 1); + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 2); + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 3); + } + } +} + #endif /* HEX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 4ff2b36de96..ec191c14981 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -20,7 +20,7 @@ #include "hvx-dump.h" #include "worker-pool.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "hmx-utils.h" #include "hmx-ops.h" @@ -821,7 +821,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu // and each q_head is computed individually to avoid tile-major packing // issues. m_chunk_n_rows is always a multiple of 32 (from // hmx_compute_chunks), so per-head tile arrays don't overlap. - const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vtcm_budget = ctx->vtcm_size; const size_t vec_dot_size = params->k * sizeof(__fp16); // When the activation has a large stride (e.g. permuted Q tensor with @@ -998,7 +998,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co } // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vtcm_budget = ctx->vtcm_size; const size_t vec_dot_size = k * sizeof(__fp16); // DMA-based activation gather for strided tensors (see batched path comment). @@ -1182,7 +1182,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds FARF(MEDIUM, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vtcm_budget = ctx->vtcm_size; const size_t vec_dot_size = k * sizeof(__fp16); const bool use_pipeline = (m >= 128) && (k <= n); @@ -1273,9 +1273,6 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds void *buf_curr = vtcm_scratch0; void *buf_next = vtcm_scratch1; - // issue async DDR data transfer for the first weight chunk - // NOTE: use 2D DMA (n_cols rows x row_stride bytes) instead of 1D - // because UDMA roiwidth is 16-bit and total size can exceed 65535. { const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); @@ -1533,20 +1530,15 @@ void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, co worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); } -int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m, - int k, int n, int weight_type) { - // Runtime check -- k >= 16384 exceeds 2D DMA limit - if (k >= 16384) { - FARF(HIGH, "%s: k=%d exceeds 2D DMA limit", __func__, k); - return -1; - } +int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, + int m, int k, int n, int weight_type) { // assume k % 32 == 0 && n % 32 == 0 const size_t row_stride = get_x4x2_row_stride(weight_type, k); if (row_stride == 0) { return -1; } - const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vtcm_budget = ctx->vtcm_size; const size_t M_BLOCK_SIZE = 512; const size_t N_BLOCK_SIZE = 512; @@ -1576,8 +1568,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu", - __func__, m, k, n, weight_type, + FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu", __func__, m, k, n, weight_type, (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); // initialize eye tile (32x32 identity matrix) diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index b36c8d129ba..fb95d36f5a9 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -7,16 +7,12 @@ #include #include -#ifndef restrict -# define restrict __restrict -#endif +#include "htp-ops.h" #ifdef __cplusplus extern "C" { #endif -struct htp_context; // forward declaration - typedef struct { float *dst; const float *activation; diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 6f1917fa2cb..4c36a6ea0c2 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -2,6 +2,7 @@ #define HTP_CTX_H #include "hex-dma.h" +#include "htp-ops.h" #include "worker-pool.h" #include @@ -10,38 +11,85 @@ #include #define HTP_MAX_NTHREADS 10 +#define HTP_MAX_MMAPS 16 + +// Memory mapping +struct htp_mmap { + uint64_t size; + uint64_t base; + uint32_t fd; + uint32_t pinned; +}; + +// Scratchpad state +struct htp_spad { + const struct htp_tensor * src; // original src of the data (for reuse) + uint8_t * data; // pointer to an area in vtcm + uint32_t stride; // stride used inside this spad + uint32_t size; // total size + uint32_t size_per_thread; // size per thread +}; + +// Context while processing an Op +// TODO: fold this into the main context +struct htp_ops_context { + struct htp_context * ctx; + + enum htp_op_code op; // FIXME: rename to opcode + int32_t op_params[HTP_OP_MAX_PARAMS]; + + const struct htp_tensor * src[HTP_OP_MAX_INPUTS]; + const struct htp_tensor * dst; + + // TODO convert these to an array + struct htp_spad src0_spad; + struct htp_spad src1_spad; + struct htp_spad src2_spad; + struct htp_spad src3_spad; + struct htp_spad dst_spad; + + uint32_t n_threads; + uint32_t flags; +}; // Main context for htp DSP backend struct htp_context { - dspqueue_t queue; - dma_queue * dma[HTP_MAX_NTHREADS]; - worker_pool_context_t worker_pool; - uint32_t n_threads; - - int thread_id; - int thread_prio; - - uint8_t * vtcm_base; - size_t vtcm_size; - uint32_t vtcm_rctx; - - atomic_bool vtcm_valid; - atomic_bool vtcm_inuse; - atomic_bool vtcm_needs_release; - - uint32_t opmask; - - // Cached src1 spad position from the last quantize pass. - // When SKIP_QUANTIZE is set the Q8 activation data is already in VTCM - // at this address; the matmul must read from here instead of recomputing - // the offset (which depends on the current op's src0 size). - uint8_t * prev_src1_spad; - - // HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX) -#ifdef HTP_HAS_HMX - int hmx_enabled; // Runtime flag: HMX initialisation succeeded - size_t vtcm_scratch_size; // Usable dynamic scratch (vtcm_size minus tail reservation) -#endif + dspqueue_t queue; + dma_queue * dma[HTP_MAX_NTHREADS]; + struct htp_mmap mmap[HTP_MAX_MMAPS]; + worker_pool_context_t worker_pool; + uint32_t n_threads; + + int thread_id; + int thread_prio; + + int hmx_enabled; + + uint8_t * vtcm_base; + size_t vtcm_size; + uint32_t vtcm_rctx; + atomic_bool vtcm_valid; + atomic_bool vtcm_needs_release; + + struct htp_ops_context octx; }; +int op_matmul(struct htp_ops_context * octx); +int op_matmul_id(struct htp_ops_context * octx); +int op_binary(struct htp_ops_context * octx); +int op_unary(struct htp_ops_context * octx); +int op_sum_rows(struct htp_ops_context * octx); +int op_activations(struct htp_ops_context * octx); +int op_softmax(struct htp_ops_context * octx); +int op_add_id(struct htp_ops_context * octx); +int op_rope(struct htp_ops_context * octx); +int op_flash_attn_ext(struct htp_ops_context * octx); +int op_set_rows(struct htp_ops_context * octx); +int op_get_rows(struct htp_ops_context * octx); +int op_cpy(struct htp_ops_context * octx); +int op_repeat(struct htp_ops_context * octx); +int op_argsort(struct htp_ops_context * octx); +int op_ssm_conv(struct htp_ops_context * octx); +int op_cumsum(struct htp_ops_context * octx); + #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h deleted file mode 100644 index df0ea7ccbd6..00000000000 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ /dev/null @@ -1,166 +0,0 @@ -#ifndef HTP_MSG_H -#define HTP_MSG_H - -#include - -// ggml-common.h must be included prio to this header - -// Mask to enable various stages of the Ops. -// Used for debugging and profiling. -enum { - HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP) - HTP_OPMASK_QUANTIZE = (1 << 1), // Enable Quantize - HTP_OPMASK_COMPUTE = (1 << 2), // Enable Compute -}; - -// Op flags -enum { - HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0), // Skip dynamic quantization (reuse quantized tensors) - HTP_OPFLAGS_SKIP_COMPUTE = (1 << 1), // Skip actual computation (used for profiling) - HTP_OPFLAGS_EARLY_WAKEUP = (1 << 2) // Send early wakeup notification -}; - -enum htp_status { - HTP_STATUS_OK = 1, - HTP_STATUS_INTERNAL_ERR = 2, - HTP_STATUS_NO_SUPPORT = 3, - HTP_STATUS_INVAL_PARAMS = 4, - HTP_STATUS_VTCM_TOO_SMALL = 5, -}; - -// The values must match the ggml_type. -// Duplicated here because we can't include full ggml.h in the htp build. -// We have some static_asserts in the cpp code to ensure things are in sync. -enum htp_data_type { - HTP_TYPE_F32 = 0, - HTP_TYPE_F16 = 1, - HTP_TYPE_Q4_0 = 2, - HTP_TYPE_Q8_0 = 8, - HTP_TYPE_IQ4_NL = 20, - HTP_TYPE_I32 = 26, - HTP_TYPE_I64 = 27, - HTP_TYPE_MXFP4 = 39, - HTP_TYPE_COUNT -}; - -// Do not reorder first 4 (used as an index) -enum htp_op { - HTP_OP_MUL = 0, - HTP_OP_ADD = 1, - HTP_OP_SUB = 2, - HTP_OP_DIV = 3, - HTP_OP_MUL_MAT, - HTP_OP_MUL_MAT_ID, - HTP_OP_RMS_NORM, - HTP_OP_UNARY_SILU, - HTP_OP_UNARY_GELU, - HTP_OP_UNARY_SIGMOID, - HTP_OP_UNARY_EXP, - HTP_OP_UNARY_NEG, - HTP_OP_UNARY_SOFTPLUS, - HTP_OP_GLU_SWIGLU, - HTP_OP_GLU_SWIGLU_OAI, - HTP_OP_GLU_GEGLU, - HTP_OP_SOFTMAX, - HTP_OP_ADD_ID, - HTP_OP_ROPE, - HTP_OP_FLASH_ATTN_EXT, - HTP_OP_SET_ROWS, - HTP_OP_GET_ROWS, - HTP_OP_SCALE, - HTP_OP_CPY, - HTP_OP_ARGSORT, - HTP_OP_SQR, - HTP_OP_SQRT, - HTP_OP_SUM_ROWS, - HTP_OP_SSM_CONV, - HTP_OP_REPEAT, - HTP_OP_CUMSUM, - INVALID -}; - -static inline size_t htp_t_block_size(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return 1; - case HTP_TYPE_F16: - return 1; - case HTP_TYPE_Q4_0: - return QK4_0; - case HTP_TYPE_Q8_0: - return QK8_0; - case HTP_TYPE_IQ4_NL: - return QK4_NL; - case HTP_TYPE_MXFP4: - return QK_MXFP4; - default: - assert(0 && "unsupported HTP data type"); - } - return 0; -} - -static inline size_t htp_type_nbytes(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return 4; - case HTP_TYPE_F16: - return 2; - case HTP_TYPE_Q4_0: - return sizeof(block_q4_0); - case HTP_TYPE_Q8_0: - return sizeof(block_q8_0); - case HTP_TYPE_IQ4_NL: - return sizeof(block_iq4_nl); - case HTP_TYPE_MXFP4: - return sizeof(block_mxfp4); - default: - assert(0 && "unsupported HTP data type"); - } - return 0; -} - -// Internal types -#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) -#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks -#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks - -#define HTP_MAX_DIMS 4 - -struct htp_tensor { - uint32_t data; // Buffer offset in the messages, and data pointer on the NSP - uint32_t type; // Data type - uint32_t ne[HTP_MAX_DIMS]; // Number of elements - uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor) -}; - -#define HTP_MAX_OP_PARAMS 64 - -struct htp_general_req { - uint32_t op; // GGML/HTP Op - int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)]; - // Params for the op, e.g. epsilon of RMS norm - uint32_t flags; // Request flags - - struct htp_tensor src0; // Input0 tensor - struct htp_tensor src1; // Input1 tensor - struct htp_tensor src2; // Input2 tensor - struct htp_tensor src3; // Input3 tensor - struct htp_tensor src4; // Input4 tensor - struct htp_tensor dst; // Output tensor - - // should be multiple of 64 bytes (cacheline) -}; - -struct htp_general_rsp { - uint32_t op; // GGML/HTP Op - uint32_t status; // HTP_STATUS_... - uint32_t prof_usecs; // Number of usec per request - uint32_t prof_cycles; // Number of cycles per request - uint32_t prof_pkts; // Number of instruction packets per request - uint8_t unused[44]; // Pad to 64 bytes -}; - -#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req) -#define HTP_MAX_PACKET_BUFFERS 8 - -#endif /* HTP_MSG_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index d35decaac20..44a6ab4f737 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -1,65 +1,154 @@ #ifndef HTP_OPS_H #define HTP_OPS_H -#include "htp-ctx.h" -#include "htp-msg.h" -#include "worker-pool.h" - #include -#include -#include +// ggml-common.h must be included prio to this header + +enum htp_status { + HTP_STATUS_OK = 1, + HTP_STATUS_INTERNAL_ERR = 2, + HTP_STATUS_NO_SUPPORT = 3, + HTP_STATUS_INVAL_PARAMS = 4, + HTP_STATUS_VTCM_TOO_SMALL = 5, +}; + +// First set of values must match the ggml_type. +// Duplicated here because we can't include full ggml.h in the htp build. +// We have some static_asserts in the cpp code to ensure things are in sync. +enum htp_data_type { + HTP_TYPE_F32 = 0, + HTP_TYPE_F16 = 1, + HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q8_0 = 8, + HTP_TYPE_IQ4_NL = 20, + HTP_TYPE_I32 = 26, + HTP_TYPE_I64 = 27, + HTP_TYPE_MXFP4 = 39, + + // types used internally for repack, dyn.quant, etc + HTP_TYPE_Q4_0x4x2 = 200, + HTP_TYPE_Q8_0x4x2, + HTP_TYPE_MXFP4x4x2, + + HTP_TYPE_INVALID +}; + +// Constats for internal types +#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) +#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks +#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks + + +// Mask to enable various stages of the Ops. +// Used for debugging and profiling. +enum htp_op_mask { + HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP) + HTP_OPMASK_COMPUTE = (1 << 1), // Enable Compute +}; + +// Do not reorder first 4 (used as an index) +enum htp_op_code { + HTP_OP_MUL = 0, + HTP_OP_ADD = 1, + HTP_OP_SUB = 2, + HTP_OP_DIV = 3, + HTP_OP_MUL_MAT, + HTP_OP_MUL_MAT_ID, + HTP_OP_RMS_NORM, + HTP_OP_UNARY_SILU, + HTP_OP_UNARY_GELU, + HTP_OP_UNARY_SIGMOID, + HTP_OP_UNARY_EXP, + HTP_OP_UNARY_NEG, + HTP_OP_UNARY_SOFTPLUS, + HTP_OP_GLU_SWIGLU, + HTP_OP_GLU_SWIGLU_OAI, + HTP_OP_GLU_GEGLU, + HTP_OP_SOFTMAX, + HTP_OP_ADD_ID, + HTP_OP_ROPE, + HTP_OP_FLASH_ATTN_EXT, + HTP_OP_SET_ROWS, + HTP_OP_GET_ROWS, + HTP_OP_SCALE, + HTP_OP_CPY, + HTP_OP_ARGSORT, + HTP_OP_SQR, + HTP_OP_SQRT, + HTP_OP_SUM_ROWS, + HTP_OP_SSM_CONV, + HTP_OP_REPEAT, + HTP_OP_CUMSUM, + + HTP_OP_INVALID +}; + +#define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS +#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS +#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS -// ggml-common.h must be included prior to this header +#define HTP_OP_MAX_BUFS 8 +#define HTP_OP_MAX_REQS 256 +#define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS) +#define HTP_OP_MAX_VMEM (3221225472u) -struct htp_spad { - uint8_t * data; - size_t stride; - size_t size; - size_t size_per_thread; +enum htp_tensor_flags { + HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights) + HTP_TENSOR_FLUSHED = (1U << 1) // Tensor buffer has been flushed (set by the NPU) }; -struct htp_ops_context { - struct htp_context * ctx; +// Tensor descriptor +struct htp_tensor { + uint32_t data; // Buffer offset in the messages, and data pointer on the NPU + uint32_t size; // Data size in bytes + uint32_t flags; // Buffer / tensor flags + uint16_t type; // Data type + uint16_t bi; // Buffer index + uint32_t ne[HTP_OP_MAX_DIMS]; // Number of elements + uint32_t nb[HTP_OP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor) +}; - enum htp_op op; - int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)]; +// Buffer descriptor +struct htp_buf_desc { + uint64_t base; // base address + uint64_t size; // total size + uint32_t flags; // buffer flags (unused) + uint32_t fd; // file descriptor +}; - struct htp_tensor src0; - struct htp_tensor src1; - struct htp_tensor src2; - struct htp_tensor src3; - struct htp_tensor src4; - struct htp_tensor dst; +enum htp_op_flags { + HTP_OPFLAGS_SKIP_COMPUTE = (1U << 0), // Skip actual computation (used for profiling) +}; - struct htp_spad src0_spad; - struct htp_spad src1_spad; - struct htp_spad src2_spad; - struct htp_spad src3_spad; - struct htp_spad dst_spad; +// Op descriptor +struct htp_op_desc { + uint32_t opcode; // GGML/HTP Op + uint32_t flags; // Op flags + int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm + uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices + uint16_t dst; // Output tensor index - worker_pool_context_t * wpool; // worker pool - uint32_t n_threads; // num threads + // the rest is filled in-place by the NPU + uint32_t prof_usecs; // Number of usec per request + uint32_t prof_cycles; // Number of cycles per request + uint32_t prof_pkts; // Number of instruction packets per request + uint32_t unused; +}; - uint32_t flags; +struct htp_opbatch_req { + uint32_t n_bufs; // Number of buffers + uint32_t n_tensors; // Number of tensors + uint32_t n_ops; // Number of ops + uint32_t flags; // unused + // struct htp_buf_desc bufs[]; -- dspqueue buf 0 + // struct htp_tensor tensors[]; -- dspqueue buf 0 + // struct htp_op_desc ops[]; -- dspqueue buf 0 }; -int op_matmul(struct htp_ops_context * octx); -int op_matmul_id(struct htp_ops_context * octx); -int op_binary(struct htp_ops_context * octx); -int op_unary(struct htp_ops_context * octx); -int op_sum_rows(struct htp_ops_context * octx); -int op_activations(struct htp_ops_context * octx); -int op_softmax(struct htp_ops_context * octx); -int op_add_id(struct htp_ops_context * octx); -int op_rope(struct htp_ops_context * octx); -int op_flash_attn_ext(struct htp_ops_context * octx); -int op_set_rows(struct htp_ops_context * octx); -int op_get_rows(struct htp_ops_context * octx); -int op_cpy(struct htp_ops_context * octx); -int op_repeat(struct htp_ops_context * octx); -int op_argsort(struct htp_ops_context * octx); -int op_ssm_conv(struct htp_ops_context * octx); -int op_cumsum(struct htp_ops_context * octx); +struct htp_opbatch_rsp { + uint32_t status; // HTP_STATUS_... + // struct htp_op_req ops[]; -- dspqueue buf 0 +}; #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index 2dc716cb441..3eb5d5a6912 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -9,6 +9,8 @@ interface htp_iface : remote_handle64 { AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx); AEEResult stop(); + AEEResult mmap(in uint32 fd, in uint32 size, in uint32 pinned); + AEEResult munmap(in uint32 fd); AEEResult enable_etm(); AEEResult disable_etm(); }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 6f37bf9d4b8..8b347039428 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -1,5 +1,7 @@ #pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" #pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" #include #include @@ -12,6 +14,7 @@ #include #include #include +#include #include #include @@ -21,14 +24,10 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #include "worker-pool.h" -#ifdef HTP_HAS_HMX -#include "hmx-ops.h" -#endif // HTP_HAS_HMX - AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { struct htp_context * ctx; int err = 0; @@ -38,7 +37,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { return AEE_ENOMEMORY; } - // Use the context structure as a handle + // Use the context structure as the handle *handle = (remote_handle64) ctx; // Enable FARF logs @@ -115,6 +114,16 @@ AEEResult htp_iface_close(remote_handle64 handle) { return AEE_EITEMBUSY; } + // release the mmaps (if any) + for (uint32_t i=0; immap[i].size) { + HAP_munmap2((void *) ctx->mmap[i].base, ctx->mmap[i].size); + ctx->mmap[i].size = 0; + ctx->mmap[i].base = NULL; + ctx->mmap[i].fd = -1; + } + } + free(ctx); return AEE_SUCCESS; } @@ -143,66 +152,93 @@ AEEResult htp_iface_disable_etm(remote_handle64 handle) { return err; } -static int vtcm_acquire(struct htp_context * ctx) { - int err; - if (!ctx->vtcm_valid) { - // Temporarily bump thread priority to make sure it's higher than other sessions. - // This way the resource manager will notify the other thread to release VTCM. - // Note that we need to reaquire VTCM at normal priority for this to work next time. - qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10); - err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); - if (err != 0) { - FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); - abort(); - } - HAP_compute_res_release_cached(ctx->vtcm_rctx); - qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio); +AEEResult htp_iface_mmap(remote_handle64 handle, int fd, uint32_t size, uint32_t pinned) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } - err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); - if (err != 0) { - FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); - abort(); + // See if we already have this mapping + for (uint32_t i=0; immap[i]; + if (m->fd == fd) { + m->pinned = pinned; + return AEE_SUCCESS; } - ctx->vtcm_valid = true; } - ctx->vtcm_inuse = true; + // Add new mapping + for (uint32_t i=0; immap[i]; + if (!m->size) { + FARF(HIGH, "mmap : fd %u size %u pinned %u", fd, size, pinned); + void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0); + if (va == (void*)-1) { + FARF(ERROR, "mmap failed : va %p fd %u size %u", va, fd, (uint32_t) size); + return AEE_EFAILED; + } + m->base = (uint64_t) va; + m->fd = fd; + m->size = size; + m->pinned = pinned; - return 0; + return AEE_SUCCESS; + } + } + + return AEE_ENOMEMORY; } -static int vtcm_release(struct htp_context * ctx) { - ctx->vtcm_inuse = false; +AEEResult htp_iface_munmap(remote_handle64 handle, int fd) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } - if (ctx->vtcm_valid && ctx->vtcm_needs_release) { - ctx->vtcm_valid = false; - ctx->vtcm_needs_release = false; - HAP_compute_res_release_cached(ctx->vtcm_rctx); + for (uint32_t i=0; immap[i]; + if (fd < 0 || m->fd == fd) { + FARF(HIGH, "unmmap : base %p fd %u size %u", (void*) m->base, m->fd, (uint32_t) m->size); + HAP_munmap2((void *) m->base, m->size); + m->size = 0; + m->base = NULL; + m->fd = -1; + m->pinned = 0; + } } - return 0; + return AEE_SUCCESS; } -static int vtcm_release_callback(unsigned int rctx, void * state) { - struct htp_context * ctx = (struct htp_context *) state; - - if (!ctx || ctx->vtcm_rctx != rctx) { - return AEE_EBADPARM; - } +static void vtcm_acquire(struct htp_context * ctx) { + if (!ctx->vtcm_valid) { + int err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000u); + if (err != 0) { + FARF(ERROR, "ggml-hex: failed to acquire VTCM: 0x%08x", (unsigned)err); + abort(); + } - // If VTCM is not inuse (not processing Ops) release it right here - // otherwise we'll release it once we're done with the current Op. + ctx->vtcm_needs_release = false; + ctx->vtcm_valid = true; - if (ctx->vtcm_inuse) { - ctx->vtcm_needs_release = true; - return 0; + // Drop the priority to make sure we get the release callback from other GGML-HTP and QNN-HTP sessions + HAP_compute_res_update_priority(ctx->vtcm_rctx, ctx->thread_prio + 10); } +} - ctx->vtcm_valid = false; - HAP_compute_res_release_cached(ctx->vtcm_rctx); +static void vtcm_release(struct htp_context * ctx) { + if (ctx->vtcm_valid) { + ctx->vtcm_valid = false; + ctx->vtcm_needs_release = false; + HAP_compute_res_release_cached(ctx->vtcm_rctx); + } +} +static int vtcm_release_callback(unsigned int rctx, void * state) { + struct htp_context * ctx = (struct htp_context *) state; + ctx->vtcm_needs_release = true; return 0; } @@ -236,7 +272,6 @@ static int vtcm_alloc(struct htp_context * ctx) { ctx->vtcm_size = vtcm_size; ctx->vtcm_rctx = rctx; ctx->vtcm_valid = false; - ctx->vtcm_inuse = false; ctx->vtcm_needs_release = false; return 0; @@ -288,18 +323,8 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que } #ifdef HTP_HAS_HMX - if (use_hmx) { - ctx->vtcm_scratch_size = ctx->vtcm_size; - ctx->hmx_enabled = 1; - - FARF(HIGH, "HMX enabled: vtcm-scratch %zu", ctx->vtcm_scratch_size); - } else { - // HMX disabled: skip HMX initialisation so the - // dispatch loop falls through to the HVX compute paths. - ctx->hmx_enabled = 0; - ctx->vtcm_scratch_size = ctx->vtcm_size; - FARF(HIGH, "HMX disabled (use_hmx=0): vtcm-scratch %zu", ctx->vtcm_scratch_size); - } + ctx->hmx_enabled = use_hmx; + FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx); #endif qurt_sysenv_max_hthreads_t hw_threads; @@ -362,13 +387,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) { for (int i = 0; i < ctx->n_threads; i++) { dma_queue_delete(ctx->dma[i]); } + #ifdef HTP_HAS_HMX - if (ctx->hmx_enabled) { - ctx->hmx_enabled = 0; - } + ctx->hmx_enabled = 0; #endif - vtcm_free(ctx); return AEE_SUCCESS; @@ -397,1129 +420,320 @@ static inline void profile_stop(struct profile_data * d) { d->pkts = hex_get_pktcnt() - d->pkts; } -static int send_htp_rsp(struct htp_context * c, - uint32_t op, - uint32_t status, - struct dspqueue_buffer * bufs, - size_t n_bufs, - struct profile_data * prof) { - // Prep response struct (zero-init to clear cmp/unused union) - struct htp_general_rsp rsp; - memset(&rsp, 0, sizeof(rsp)); - rsp.op = op; - rsp.status = status; - rsp.prof_usecs = prof->usecs; - rsp.prof_cycles = prof->cycles; - rsp.prof_pkts = prof->pkts; - - int err = dspqueue_write(c->queue, - 0, // Flags - n_bufs, - bufs, // Buffer references - sizeof(rsp), - (const uint8_t *) &rsp, // Message - DSPQUEUE_TIMEOUT_NONE); +static int execute_op(struct htp_ops_context * octx) { + switch (octx->op) { + case HTP_OP_MUL_MAT: + return op_matmul(octx); - if (err != 0) { - FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); - } + case HTP_OP_MUL_MAT_ID: + return op_matmul_id(octx); - return err; -} + case HTP_OP_MUL: + case HTP_OP_ADD: + case HTP_OP_SUB: + case HTP_OP_DIV: + case HTP_OP_ADD_ID: + return op_binary(octx); -static void proc_matmul_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - size_t n_bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_matmul(&octx); - vtcm_release(ctx); - } + case HTP_OP_RMS_NORM: + case HTP_OP_SCALE: + case HTP_OP_SQR: + case HTP_OP_SQRT: + case HTP_OP_UNARY_SOFTPLUS: + case HTP_OP_UNARY_SIGMOID: + case HTP_OP_UNARY_NEG: + case HTP_OP_UNARY_EXP: + return op_unary(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_UNARY_SILU: + case HTP_OP_UNARY_GELU: + case HTP_OP_GLU_SWIGLU: + case HTP_OP_GLU_SWIGLU_OAI: + case HTP_OP_GLU_GEGLU: + return op_activations(octx); -static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_argsort(&octx); - vtcm_release(ctx); - } + case HTP_OP_SOFTMAX: + return op_softmax(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_ROPE: + return op_rope(octx); -static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_cpy(&octx); - vtcm_release(ctx); - } + case HTP_OP_FLASH_ATTN_EXT: + return op_flash_attn_ext(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_SET_ROWS: + return op_set_rows(octx); -static void proc_repeat_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = op_repeat(&octx); - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_GET_ROWS: + return op_get_rows(octx); -static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_get_rows(&octx); - vtcm_release(ctx); - } + case HTP_OP_SUM_ROWS: + return op_sum_rows(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_CPY: + return op_cpy(octx); -static void proc_matmul_id_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - size_t n_bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[3].fd; - rsp_bufs[0].ptr = bufs[3].ptr; - rsp_bufs[0].size = bufs[3].size; - rsp_bufs[0].offset = bufs[3].offset; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.src2 = req->src2; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.src2.data = (uint32_t) bufs[2].ptr; - octx.dst.data = (uint32_t) bufs[3].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_matmul_id(&octx); - vtcm_release(ctx); - } + case HTP_OP_REPEAT: + return op_repeat(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_ARGSORT: + return op_argsort(octx); -static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_binary(&octx); - vtcm_release(ctx); - } + case HTP_OP_SSM_CONV: + return op_ssm_conv(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_CUMSUM: + return op_cumsum(octx); -static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[3].fd; - rsp_bufs[0].ptr = bufs[3].ptr; - rsp_bufs[0].offset = bufs[3].offset; - rsp_bufs[0].size = bufs[3].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.src2 = req->src2; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.src2.data = (uint32_t) bufs[2].ptr; - octx.dst.data = (uint32_t) bufs[3].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_binary(&octx); - vtcm_release(ctx); - } + case HTP_OP_INVALID: + break; - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} - -static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_unary(&octx); - vtcm_release(ctx); + // No default to catch missing cases } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); + FARF(ERROR, "Unknown Op %u", octx->op); + return -1; } -static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_sum_rows(&octx); - vtcm_release(ctx); - } +static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct htp_buf_desc *b) { + b->base = NULL; - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} - -static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - // We've written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup OP context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_ssm_conv(&octx); - vtcm_release(ctx); + for (uint32_t i=0; immap + i; + if (m->size && m->fd == b->fd) { + b->base = m->base; + *m_reuse |= (1 << i); + return true; + } } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); + return false; } -static void proc_cumsum_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We've written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_cumsum(&octx); - vtcm_release(ctx); +static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) { + if (m->size && !m->pinned) { + FARF(HIGH, "unmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned); + HAP_munmap2((void *) m->base, m->size); + m->size = 0; + m->base = 0; + m->fd = -1; } - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } -static void proc_activations_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - uint32_t n_bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - int write_idx = (n_bufs == 3) ? 2 : 1; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[write_idx].fd; - rsp_bufs[0].ptr = bufs[write_idx].ptr; - rsp_bufs[0].offset = bufs[write_idx].offset; - rsp_bufs[0].size = bufs[write_idx].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - if (3 == n_bufs) { - octx.src1 = req->src1; - } - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - if (3 == n_bufs) { - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - } else { - octx.dst.data = (uint32_t) bufs[1].ptr; - } - octx.n_threads = ctx->n_threads; +static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) { + if (b->base) return; // already mapped - struct profile_data prof; - profile_start(&prof); + // find unused mapping + for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) { + struct htp_mmap *m = &ctx->mmap[i]; + if (!m->size) { + void *va = HAP_mmap2(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0); + if (va == (void*)-1) { + FARF(ERROR, "mmap failed : va %p fd %u size %u", va, b->fd, (uint32_t) b->size); + abort(); // can't do much else at this point + } - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - if (octx.op == HTP_OP_SOFTMAX) { - rsp_status = op_softmax(&octx); - } else { - rsp_status = op_activations(&octx); + m->base = b->base = (uint64_t) va; + m->fd = b->fd; + m->size = b->size; + m->pinned = 0; + + FARF(HIGH, "mmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned); + return; } - vtcm_release(ctx); } - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } -static void proc_rope_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - uint32_t n_bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - int write_idx = n_bufs - 1; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[write_idx].fd; - rsp_bufs[0].ptr = bufs[write_idx].ptr; - rsp_bufs[0].offset = bufs[write_idx].offset; - rsp_bufs[0].size = bufs[write_idx].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - if (4 == n_bufs) { - octx.src2 = req->src2; - } - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - if (4 == n_bufs) { - octx.src2.data = (uint32_t) bufs[2].ptr; - octx.dst.data = (uint32_t) bufs[3].ptr; - } else { - octx.dst.data = (uint32_t) bufs[2].ptr; - } - octx.n_threads = ctx->n_threads; +static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uint32_t n_bufs) { + uint32_t m_reuse = 0; // mmap reuse mask (index from ctx->mmap array) + uint32_t b_reuse = 0; // buf reuse count - struct profile_data prof; - profile_start(&prof); + size_t m_vmem = 0; // mapped vmem + size_t e_vmem = 0; // extra vmem - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_rope(&octx); - vtcm_release(ctx); + // See what we can reuse + for (uint32_t i=0; i < n_bufs; i++) { + struct htp_buf_desc *b = bufs + i; + if (reuse_buf(ctx, &m_reuse, b)) { b_reuse++; } else { e_vmem += b->size; } + FARF(HIGH, "prep-buf #%u : pass0 fd %u base %p size %u flags 0x%x", i, b->fd, (void*) b->base, (uint32_t) b->size, b->flags); } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + if (b_reuse == n_bufs) return; // all bufs reuse existing mappings -static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_set_rows(&octx); - vtcm_release(ctx); - } + // See how much vmem we have mmaped right now + for (uint32_t i=0; immap[i].size; } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} - -static void proc_flash_attn_ext_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - uint32_t n_bufs) { - // Setup Op context - struct htp_ops_context octx; - memset(&octx, 0, sizeof(octx)); - - octx.ctx = ctx; - octx.n_threads = ctx->n_threads; - - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.src2 = req->src2; - octx.src3 = req->src3; - octx.src4 = req->src4; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.src2.data = (uint32_t) bufs[2].ptr; - - int last_buf = 3; - - if (octx.src3.ne[0]) { - octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid - } + FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu n-bufs %u b-reuse %u", m_vmem, e_vmem, n_bufs, b_reuse); - if (octx.src4.ne[0]) { - octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid + if ((m_vmem + e_vmem) > HTP_OP_MAX_VMEM) { + // Drop unused mappings + for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) { + bool used = m_reuse & (1<mmap + i); } + } } - octx.dst.data = (uint32_t) bufs[last_buf].ptr; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_flash_attn_ext(&octx); - vtcm_release(ctx); + // Create missing mappings + for (uint32_t i=0; i < n_bufs; i++) { + struct htp_buf_desc *b = bufs + i; + mmap_buf(ctx, b); + FARF(HIGH, "prep-buf #%u : pass1 fd %u base %p size %u flags 0x%x", i, b->fd, (void*) b->base, (uint32_t) b->size, b->flags); } +} - profile_stop(&prof); +static void prep_tensor(struct htp_context *ctx, struct htp_buf_desc *bufs, uint32_t idx, struct htp_tensor *t) { + uint32_t offset = t->data; + uint32_t size = t->size; + uint32_t bi = t->bi; - struct dspqueue_buffer rsp_buf = bufs[last_buf]; - rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + t->data = bufs[bi].base + offset; // update data to the actual pointer - send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof); + FARF(HIGH, "prep-tensor #%u: bi %u offset %u size %u data %p : %u:%u:%u:%u", idx, t->bi, offset, t->size, (void*) t->data, + t->ne[0], t->ne[1], t->ne[3], t->ne[3]); } -#ifdef HTP_HAS_HMX -// --------------------------------------------------------------------------- -// HMX operation wrappers — self-contained, bypass htp_ops_context / htp_spad. -// VTCM, DMA and thread dispatch are managed inside the HMX kernels. -// --------------------------------------------------------------------------- - -static void proc_hmx_matmul_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - size_t n_bufs) { - // HMX weight tile requires N to be 32-aligned. - if (req->src0.ne[1] % 32 != 0) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; +static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, struct htp_tensor *tens, uint32_t n_tens) { + for (uint32_t i=0; i < n_tens; i++) { + prep_tensor(ctx, bufs, i, tens + i); } +} - const bool is_batched = (req->src0.ne[2] * req->src0.ne[3] > 1 || - req->src1.ne[2] * req->src1.ne[3] > 1); +static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) { + memcpy(octx->op_params, op->params, sizeof(octx->op_params)); + octx->flags = op->flags; + octx->op = op->opcode; - // Quantised HMX kernels only handle flat 2D matmul (host already rejects - // batched quantised, but guard here too). F16 batched matmul is handled - // by the dedicated wrapper in hmx-matmul-ops.c. - if (is_batched && - req->src0.type != HTP_TYPE_F16) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; - } + FARF(HIGH, "proc-op #%u: opcode %u flags 0x%x", idx, octx->op, octx->flags); - // HMX assumes contiguous row-major layout. Fall back for permuted - // tensors where strides are non-monotonic (e.g. transposed KV cache). - if (req->src0.nb[0] > req->src0.nb[1] || - req->src1.nb[0] > req->src1.nb[1]) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; - } + // Prep input tensors + for (uint32_t i=0; isrc[i] == 0xffff ? NULL : tens + op->src[i]; - // M alignment: when M > 32 but not 32-aligned, we split into - // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows). - // When M <= 32 and not 32-aligned, fall back entirely to HVX. - const int m_total = (int) req->src1.ne[1]; - const int m_tail = m_total % 32; - const int m_hmx = m_total - m_tail; + octx->src[i] = src; + if (!src) continue; - if (m_hmx == 0) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; - } - - // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. - // Other types fall back to HVX. - { - uint32_t wtype = req->src0.type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && - wtype != HTP_TYPE_MXFP4) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; - } - // Quantised HMX path requires K aligned to 256 (x4x2 super-block). - // F16 HMX path requires K aligned to 32 (tile width). - if (wtype != HTP_TYPE_F16 && req->src0.ne[0] % 256 != 0) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; - } - if (wtype == HTP_TYPE_F16 && req->src0.ne[0] % 32 != 0) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; + if (!(src->flags & HTP_TENSOR_FLUSHED) && (src->flags & HTP_TENSOR_COMPUTE)) { + // flush compute buffers on input + hex_l2flush((void *) src->data, src->size); } + + FARF(HIGH, "prep-src #%u: data %p size %u : %u:%u:%u:%u", op->src[i], (void*) src->data, src->size, + src->ne[0], src->ne[1], src->ne[3], src->ne[3]); } - (void) n_bufs; - - struct dspqueue_buffer rsp_bufs[1]; - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); - - // src0 = weights, src1 = activation, dst = output - void * wgt = (void *) bufs[0].ptr; - float * act = (float *) bufs[1].ptr; - float * dst = (float *) bufs[2].ptr; - - int k = (int) req->src0.ne[0]; // inner dimension - int n = (int) req->src0.ne[1]; // weight columns - - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - - // --- Phase 1: HMX on the first m_hmx (32-aligned) rows --- - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - int ret = -1; - - const int ne02 = (int) req->src0.ne[2]; - const int ne03 = (int) req->src0.ne[3]; - const int ne12 = (int) req->src1.ne[2]; - const int ne13 = (int) req->src1.ne[3]; - // Row strides in elements. For compact tensors these equal k; for - // permuted attention views they can be larger, so pass the real stride. - const int act_stride = (int)(req->src1.nb[1] / sizeof(float)); - const int weight_stride = (int)(req->src0.nb[1] / sizeof(__fp16)); - - switch (req->src0.type) { - case HTP_TYPE_F16: - if (is_batched) { - hmx_matmul_w16a32_batched_params_t batch_params = { - .dst = dst, - .activation = act, - .permuted_weight = (const __fp16 *) wgt, - .m = m_hmx, - .k = k, - .n = n, - .act_stride = act_stride, - .weight_stride = weight_stride, - .dst_stride = (int)(req->dst.nb[1] / sizeof(float)), - .ne02 = ne02, - .ne03 = ne03, - .ne12 = ne12, - .ne13 = ne13, - .src0_nb2 = req->src0.nb[2], - .src0_nb3 = req->src0.nb[3], - .src1_nb2 = req->src1.nb[2], - .src1_nb3 = req->src1.nb[3], - .dst_nb2 = req->dst.nb[2], - .dst_nb3 = req->dst.nb[3], - }; - ret = hmx_mat_mul_permuted_w16a32_batched(ctx, &batch_params); - } else { - ret = hmx_mat_mul_permuted_w16a32(ctx, dst, act, - (const __fp16 *) wgt, - m_hmx, k, n, - act_stride, - weight_stride); - } - break; - default: - ret = hmx_mat_mul_permuted_qk_0_d16a32(ctx, dst, act, - (const uint8_t *) wgt, - m_hmx, k, n, (int) req->src0.type); - break; - } + // Prep output tensor + struct htp_tensor *dst = tens + op->dst; - if (ret == 0) { - rsp_status = HTP_STATUS_OK; - } else { - FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret); - vtcm_release(ctx); - req->flags &= ~HTP_OPFLAGS_SKIP_QUANTIZE; - proc_matmul_req(ctx, req, bufs, n_bufs); - return; - } - vtcm_release(ctx); - } + octx->dst = dst; - // --- Phase 2: HVX on the remaining m_tail rows --- - if (m_tail > 0 && rsp_status == HTP_STATUS_OK) { - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; // weights: unchanged - octx.src1 = req->src1; - octx.src1.ne[1] = m_tail; // only tail rows - octx.dst = req->dst; - octx.dst.ne[1] = m_tail; // only tail rows - // Always re-quantize tail src1: HMX Phase 1 overwrites VTCM, - // so any previously cached quantized data (SKIP_QUANTIZE pipeline) - // is invalid. - octx.flags = req->flags & ~HTP_OPFLAGS_SKIP_QUANTIZE; - octx.op = req->op; - octx.n_threads = ctx->n_threads; - - // Offset activation and dst pointers past the HMX-processed rows. - // Use nb[1] (row stride in bytes) to compute the byte offset. - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t)((uint8_t *) bufs[1].ptr + (size_t) m_hmx * req->src1.nb[1]); - octx.dst.data = (uint32_t)((uint8_t *) bufs[2].ptr + (size_t) m_hmx * req->dst.nb[1]); - - FARF(HIGH, "proc_hmx_matmul: HVX tail m_tail=%d act=%p dst=%p", - m_tail, (void *)(uintptr_t) octx.src1.data, (void *)(uintptr_t) octx.dst.data); - - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - uint32_t hvx_ret = op_matmul(&octx); - vtcm_release(ctx); - if (hvx_ret != HTP_STATUS_OK) { - FARF(ERROR, "HVX tail matmul failed (ret=%u)", hvx_ret); - rsp_status = HTP_STATUS_INTERNAL_ERR; - } - } else { - rsp_status = HTP_STATUS_INTERNAL_ERR; - } - } + FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); + + (void) execute_op(octx); - profile_stop(&prof); + // flush buffers on output + hex_l2flush((void *) dst->data, dst->size); + dst->flags |= HTP_TENSOR_FLUSHED; - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); + FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); } -#endif // HTP_HAS_HMX +#define DSPQUEUE_POLL_TIMEOUT_USEC 100 +#define DSPQUEUE_POLL_COUNT 100 static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_context * ctx = (struct htp_context *) context; - // Repeatedly read packets from the queue until it's empty. We don't - // necessarily get a separate callback for each packet, and new packets - // may arrive while we're processing the previous one. This ensures we - // keep the DSP busy as much as possible and avoid waiting for the CPU. + int err; + + uint32_t poll_count = DSPQUEUE_POLL_COUNT; - while (1) { - struct htp_general_req req; - uint32_t req_size; + vtcm_acquire(ctx); - struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; - uint32_t n_bufs; - uint32_t flags; + while (!ctx->vtcm_needs_release) { + struct htp_opbatch_req req; + uint32_t r_size = sizeof(req); - // Read packet from queue - int err = dspqueue_read_noblock(queue, &flags, - HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references - &n_bufs, // Number of buffer references - bufs, // Buffer references - sizeof(req), // Max message length - &req_size, // Message length - (uint8_t *) &req); // Message + struct dspqueue_buffer dbuf; + uint32_t n_dbufs = 1; + uint32_t flags = 0; + err = dspqueue_read_noblock(queue, &flags, n_dbufs, &n_dbufs, &dbuf, r_size, &r_size, (uint8_t *) &req); if (err == AEE_EWOULDBLOCK) { - // Consumed all packets available for now - return; + if (--poll_count) { + qurt_sleep(DSPQUEUE_POLL_TIMEOUT_USEC); + continue; + } + break; } if (err != 0) { FARF(ERROR, "dspqueue_read_noblock failed: 0x%08x", (unsigned) err); - return; + break; } - if (req_size != sizeof(req)) { - FARF(ERROR, "Invalid request size"); + if (r_size < sizeof(req) || n_dbufs != 1) { + FARF(ERROR, "invalid request : size %u n-dbufs %u", r_size, n_dbufs); continue; } - if (req.flags & HTP_OPFLAGS_EARLY_WAKEUP) { - // Host wants early notification - dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); + const uint32_t n_bufs = req.n_bufs; + const uint32_t n_tens = req.n_tensors; + const uint32_t n_ops = req.n_ops; + + const uint32_t b_size = sizeof(struct htp_buf_desc) * n_bufs; + const uint32_t t_size = sizeof(struct htp_tensor) * n_tens; + const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops; + + if (dbuf.size < b_size + t_size + o_size) { + FARF(ERROR, "invalid opbatch memory block size %u", dbuf.size); + break; } - // Process packet based on its message type - switch (req.op) { - case HTP_OP_MUL_MAT: - if (n_bufs != 3) { - FARF(ERROR, "Bad matmul-req buffer list"); - continue; - } -#ifdef HTP_HAS_HMX - if (ctx->hmx_enabled) { - proc_hmx_matmul_req(ctx, &req, bufs, n_bufs); - } else -#endif - { - proc_matmul_req(ctx, &req, bufs, n_bufs); - } - break; - - case HTP_OP_MUL_MAT_ID: - if (n_bufs != 4) { - FARF(ERROR, "Bad matmul-id-req buffer list"); - continue; - } - proc_matmul_id_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_MUL: - case HTP_OP_ADD: - case HTP_OP_SUB: - case HTP_OP_DIV: - if (n_bufs != 3) { - FARF(ERROR, "Bad binary-req buffer list"); - continue; - } - proc_binary_req(ctx, &req, bufs); - break; - - case HTP_OP_RMS_NORM: - case HTP_OP_SCALE: - if (n_bufs != 2) { - FARF(ERROR, "Bad unary-req buffer list"); - continue; - } - - proc_unary_req(ctx, &req, bufs); - break; - - case HTP_OP_SQR: - case HTP_OP_SQRT: - case HTP_OP_UNARY_NEG: - case HTP_OP_UNARY_EXP: - case HTP_OP_UNARY_SIGMOID: - case HTP_OP_UNARY_SOFTPLUS: - if (n_bufs != 2) { - FARF(ERROR, "Bad unary-req buffer list"); - continue; - } - - proc_unary_req(ctx, &req, bufs); - break; - - case HTP_OP_SUM_ROWS: - if (n_bufs != 2) { - FARF(ERROR, "Bad unary-req buffer list"); - continue; - } - - proc_sum_rows_req(ctx, &req, bufs); - break; - - case HTP_OP_UNARY_SILU: - case HTP_OP_UNARY_GELU: - if (n_bufs != 2) { - FARF(ERROR, "Bad act-req buffer list"); - continue; - } - proc_activations_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_GLU_SWIGLU: - case HTP_OP_GLU_SWIGLU_OAI: - case HTP_OP_SOFTMAX: - case HTP_OP_GLU_GEGLU: - if ((n_bufs != 2) && (n_bufs != 3)) { - FARF(ERROR, "Bad act-req buffer list"); - continue; - } - proc_activations_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_ADD_ID: - if (n_bufs != 4) { - FARF(ERROR, "Bad add-id-req buffer list"); - continue; - } - proc_add_id_req(ctx, &req, bufs); - break; - - case HTP_OP_ROPE: - if ((n_bufs != 3) && (n_bufs != 4)) { - FARF(ERROR, "Bad rope-req buffer list"); - continue; - } - proc_rope_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_FLASH_ATTN_EXT: - if (!(n_bufs >= 4 && n_bufs <= 6)) { - FARF(ERROR, "Bad flash-attn-ext-req buffer list"); - continue; - } - proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_SET_ROWS: - if (n_bufs != 3) { - FARF(ERROR, "Bad set-rows-req buffer list"); - continue; - } - proc_set_rows_req(ctx, &req, bufs); - break; - - case HTP_OP_GET_ROWS: - if (n_bufs != 3) { - FARF(ERROR, "Bad get-rows-req buffer list"); - continue; - } - proc_get_rows_req(ctx, &req, bufs); - break; - - case HTP_OP_CPY: - if (n_bufs != 2) { - FARF(ERROR, "Bad cpy-req buffer list"); - continue; - } - proc_cpy_req(ctx, &req, bufs); - break; - - case HTP_OP_REPEAT: - if (n_bufs != 2) { - FARF(ERROR, "Bad repeat-req buffer list"); - continue; - } - proc_repeat_req(ctx, &req, bufs); - break; - - case HTP_OP_ARGSORT: - if (n_bufs != 2) { - FARF(ERROR, "Bad argsort-req buffer list"); - continue; - } - proc_argsort_req(ctx, &req, bufs); - break; - - case HTP_OP_SSM_CONV: - if (n_bufs != 3) { - FARF(ERROR, "Bad ssm-conv-req buffer list"); - continue; - } - proc_ssm_conv_req(ctx, &req, bufs); - break; - - case HTP_OP_CUMSUM: - if (n_bufs != 2) { - FARF(ERROR, "Bad cumsum-req buffer list"); - continue; - } - proc_cumsum_req(ctx, &req, bufs); - break; - - default: - FARF(ERROR, "Unknown Op %u", req.op); - break; + // Reset poll count for valid requests + poll_count = DSPQUEUE_POLL_COUNT; + + uint8_t * m_ptr = dbuf.ptr; + struct htp_buf_desc* bufs = (struct htp_buf_desc*) m_ptr; m_ptr += b_size; + struct htp_tensor* tens = (struct htp_tensor*) m_ptr; m_ptr += t_size; + struct htp_op_desc* ops = (struct htp_op_desc*) m_ptr; + + FARF(HIGH, "processing opbatch: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", + n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size); + + prep_op_bufs(ctx, bufs, n_bufs); + prep_tensors(ctx, bufs, tens, n_tens); + + struct htp_ops_context *octx = &ctx->octx; + memset(octx, 0, sizeof(*octx)); + octx->n_threads = ctx->n_threads; + octx->ctx = ctx; + + for (uint32_t i=0; i < n_ops; i++) { + struct profile_data prof; + profile_start(&prof); + + proc_op_req(octx, tens, i, &ops[i]); + + profile_stop(&prof); + ops[i].prof_usecs = prof.usecs; + ops[i].prof_cycles = prof.cycles; + ops[i].prof_pkts = prof.pkts; + } + + // dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); + + struct htp_opbatch_rsp rsp; + rsp.status = HTP_STATUS_OK; // FIXME + + dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + err = dspqueue_write(queue, 0, 1, &dbuf, sizeof(rsp), (const uint8_t *) &rsp, DSPQUEUE_TIMEOUT_NONE); + if (err != 0) { + FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); + break; } } + + vtcm_release(ctx); } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 24b7bad6876..bac06693d81 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -16,8 +16,9 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" #include "htp-ops.h" +#include "htp-ops.h" +#include "hmx-ops.h" #define MM_SPAD_SRC0_NROWS 16 #define MM_SPAD_SRC1_NROWS 16 @@ -1897,11 +1898,11 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(&s[0], 4, rsum); } -#define htp_matmul_tensors_preamble \ - struct htp_tensor * restrict src0 = &octx->src0; \ - struct htp_tensor * restrict src1 = &octx->src1; \ - struct htp_tensor * restrict src2 = &octx->src2; \ - struct htp_tensor * restrict dst = &octx->dst; \ +#define htp_matmul_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict src2 = octx->src[2]; \ + const struct htp_tensor * restrict dst = octx->dst; \ struct htp_spad * restrict src0_spad = &octx->src0_spad; \ struct htp_spad * restrict src1_spad = &octx->src1_spad; \ struct htp_spad * restrict dst_spad = &octx->dst_spad; \ @@ -2223,8 +2224,8 @@ struct mmid_row_mapping { static void matmul_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_tensor * restrict ids = &octx->src2; - struct htp_spad * restrict src2_spad = &octx->src2_spad; + const struct htp_tensor * restrict ids = octx->src[2]; + struct htp_spad * restrict src2_spad = &octx->src2_spad; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -2342,8 +2343,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { static void matvec_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_tensor * restrict ids = &octx->src2; - struct htp_spad * restrict src2_spad = &octx->src2_spad; + const struct htp_tensor * restrict ids = octx->src[2]; + struct htp_spad * restrict src2_spad = &octx->src2_spad; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -2612,7 +2613,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; - const struct htp_tensor * src = &octx->src1; + const struct htp_tensor * src = octx->src[1]; uint8_t * restrict dst = octx->src1_spad.data; struct htp_spad * spad = &octx->src0_spad; uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; @@ -2659,7 +2660,7 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; - const struct htp_tensor * src = &octx->src1; + const struct htp_tensor * src = octx->src[1]; uint8_t * restrict dst = octx->src1_spad.data; uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; uint32_t dst_stride = octx->src1_spad.stride; @@ -2701,7 +2702,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; - const struct htp_tensor * src = &octx->src1; + const struct htp_tensor * src = octx->src[1]; uint8_t * restrict dst = octx->src1_spad.data; uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; uint32_t dst_stride = octx->src1_spad.stride; @@ -2800,7 +2801,7 @@ static void htp_mminit_spad(struct htp_ops_context * octx, octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; } -int op_matmul(struct htp_ops_context * octx) { +static int op_matmul_hvx(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; struct htp_matmul_context mmctx_struct = {0}; @@ -2824,7 +2825,7 @@ int op_matmul(struct htp_ops_context * octx) { worker_callback_t quant_job_func; worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d; - bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE); + bool need_quant = true; if (src0->type == HTP_TYPE_F16) { // Try optimized f16-f16 path first (src1 in VTCM) @@ -2838,7 +2839,7 @@ int op_matmul(struct htp_ops_context * octx) { // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); + const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { // Optimized path @@ -2915,32 +2916,170 @@ int op_matmul(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + // Place src1 spad first. We use it for dyn.quant and may reuse between ops + octx->src1_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; + octx->src0_spad.src = NULL; + octx->dst_spad.src = NULL; octx->src0_spad.stride = src0_row_size_padded; octx->src1_spad.stride = src1_row_size; - if (need_quant) { + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + return HTP_STATUS_OK; + + if (need_quant && !octx->src1_spad.src) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); - // Cache where src1 was written so subsequent SKIP_QUANTIZE ops can find it - octx->ctx->prev_src1_spad = octx->src1_spad.data; + octx->src1_spad.src = src1; + } + + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); + + return HTP_STATUS_OK; +} + +int op_matmul(struct htp_ops_context * octx) { + htp_matmul_tensors_preamble; + +#ifndef HTP_HAS_HMX + return op_matmul_hvx(octx); +#else + if (!octx->ctx->hmx_enabled) { + return op_matmul_hvx(octx); + } + + // HMX weight tile requires N to be 32-aligned. + if (src0->ne[1] % 32 != 0) { + return op_matmul_hvx(octx); + } + + // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. + // Other types fall back to HVX. + uint32_t wtype = src0->type; + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { + return op_matmul_hvx(octx); + } + + // Quantised HMX path requires K aligned to 256 (x4x2 super-block). + // F16 HMX path requires K aligned to 32 (tile width). + if (wtype != HTP_TYPE_F16 && src0->ne[0] % 256 != 0) { + return op_matmul_hvx(octx); + } + + if (wtype == HTP_TYPE_F16 && src0->ne[0] % 32 != 0) { + return op_matmul_hvx(octx); + } + + const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1); + + // Quantised HMX kernels only handle flat 2D matmul (host already rejects + // batched quantised, but guard here too). F16 batched matmul is handled + // by the dedicated wrapper in hmx-matmul-ops.c. + if (is_batched && src0->type != HTP_TYPE_F16) { + return op_matmul_hvx(octx); + } + + // HMX assumes contiguous row-major layout. Fall back for permuted + // tensors where strides are non-monotonic (e.g. transposed KV cache). + if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) { + return op_matmul_hvx(octx); + } + + // M alignment: when M > 32 but not 32-aligned, we split into + // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows). + // When M <= 32 and not 32-aligned, fall back entirely to HVX. + const int m_total = (int) src1->ne[1]; + const int m_tail = m_total % 32; + const int m_hmx = m_total - m_tail; + + if (m_hmx == 0) { + return op_matmul_hvx(octx); + } + + // Always re-quantize src1 since HMX kernel overwrites vtcm/spad, + // so any previously cached quantized data is invalid. + octx->src1_spad.src = NULL; + + int k = (int) src0->ne[0]; // inner dimension + int n = (int) src0->ne[1]; // weight columns + + // --- Phase 1: HMX on the first m_hmx (32-aligned) rows --- + int ret = -1; + + // Row strides in elements. For compact tensors these equal k; for + // permuted attention views they can be larger, so pass the real stride. + const int act_stride = (int)(src1->nb[1] / sizeof(float)); + const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16)); + + if (src0->type == HTP_TYPE_F16) { + if (is_batched) { + hmx_matmul_w16a32_batched_params_t batch_params = { + .dst = (float *) dst->data, + .activation = (float *) src1->data, + .permuted_weight = (const __fp16 *) src0->data, + .m = m_hmx, + .k = k, + .n = n, + .act_stride = act_stride, + .weight_stride = wgt_stride, + .dst_stride = (int) (dst->nb[1] / sizeof(float)), + .ne02 = ne02, + .ne03 = ne03, + .ne12 = ne12, + .ne13 = ne13, + .src0_nb2 = src0->nb[2], + .src0_nb3 = src0->nb[3], + .src1_nb2 = src1->nb[2], + .src1_nb3 = src1->nb[3], + .dst_nb2 = dst->nb[2], + .dst_nb3 = dst->nb[3], + }; + ret = hmx_mat_mul_permuted_w16a32_batched(octx->ctx, &batch_params); + } else { + ret = hmx_mat_mul_permuted_w16a32(octx->ctx, + (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, + m_hmx, k, n, act_stride, wgt_stride); + } } else { - // SKIP_QUANTIZE: Q8 data lives at the address written by the previous - // quantize pass. The current op may have a different src0 size (e.g. - // IQ4_NL vs MXFP4), so src1_spad.data computed above could be wrong. - octx->src1_spad.data = octx->ctx->prev_src1_spad; + ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx, + (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + m_hmx, k, n, (int) src0->type); } - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); + if (ret != 0) { + FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret); + return op_matmul(octx); } - return HTP_STATUS_OK; + // --- Phase 2: HVX on the remaining m_tail rows --- + if (m_tail > 0) { + // copy of src1 and dst + struct htp_tensor src1_tail = *src1; + struct htp_tensor dst_tail = *dst; + + src1_tail.ne[1] = m_tail; // only tail rows + dst_tail.ne[1] = m_tail; // only tail rows + + // Offset activation and dst pointers past the HMX-processed rows. + // Use nb[1] (row stride in bytes) to compute the byte offset. + src1_tail.data += (uint32_t) m_hmx * src1->nb[1]; + dst_tail.data += (uint32_t) m_hmx * dst->nb[1]; + + octx->src[1] = &src1_tail; + octx->dst = &dst_tail; + + FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data); + return op_matmul_hvx(octx); + } + + return 0; +#endif // HTP_HAS_HMX } int op_matmul_id(struct htp_ops_context * octx) { @@ -2950,7 +3089,7 @@ int op_matmul_id(struct htp_ops_context * octx) { struct htp_matmul_context * mmctx = &mmctx_struct; mmctx->octx = octx; - struct htp_tensor * restrict ids = &octx->src2; + const struct htp_tensor * restrict ids = octx->src[2]; const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; @@ -3003,11 +3142,17 @@ int op_matmul_id(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; + // Place src1 spad first. We use it for dyn.quant and may reuse in subseq ops. + octx->src1_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src2_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; + octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; + octx->src0_spad.src = NULL; + octx->src2_spad.src = NULL; + octx->dst_spad.src = NULL; + octx->src0_spad.stride = src0_row_size_padded; octx->src1_spad.stride = src1_row_size; @@ -3031,20 +3176,18 @@ int op_matmul_id(struct htp_ops_context * octx) { } } - // Setup worker pool callbacks - if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) { + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + return HTP_STATUS_OK; + + if (octx->src1_spad.src != src1) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); - octx->ctx->prev_src1_spad = octx->src1_spad.data; - } else { - octx->src1_spad.data = octx->ctx->prev_src1_spad; + octx->src1_spad.src = src1; } - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); - } + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/repeat-ops.c b/ggml/src/ggml-hexagon/htp/repeat-ops.c index 5db06c920e2..a6f2f0ed5f3 100644 --- a/ggml/src/ggml-hexagon/htp/repeat-ops.c +++ b/ggml/src/ggml-hexagon/htp/repeat-ops.c @@ -12,7 +12,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" struct htp_repeat_context { @@ -32,8 +32,8 @@ struct htp_repeat_context { static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) { const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data; struct htp_ops_context * octx = rctx->octx; - const struct htp_tensor * src = &octx->src0; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src = octx->src[0]; + const struct htp_tensor * dst = octx->dst; const uint32_t ne00 = src->ne[0]; const uint32_t ne01 = src->ne[1]; @@ -98,8 +98,8 @@ static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * dat } int op_repeat(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; // Validate that dst dims are multiples of src dims if (dst->ne[0] % src0->ne[0] != 0 || diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index ecedadb0fea..1d8b0796bc9 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -15,7 +15,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" // Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h @@ -253,10 +253,10 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { struct htp_rope_context * rctx = (struct htp_rope_context *) data; struct htp_ops_context * octx = rctx->octx; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * src2 = &octx->src2; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * src2 = octx->src[2]; + const struct htp_tensor * dst = octx->dst; htp_rope_preamble; @@ -284,7 +284,7 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { dma_queue * dma_queue = octx->ctx->dma[ith]; const int32_t * pos = (const int32_t *) src1->data; - const float * freq_factors = src2->data ? (const float *) src2->data : NULL; + const float * freq_factors = src2 ? (const float *) src2->data : NULL; uint32_t ir = 0; uint32_t prev_i2 = (uint32_t) -1; @@ -384,10 +384,10 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { static int execute_op_rope_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * src2 = &octx->src2; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * src2 = octx->src[2]; + const struct htp_tensor * dst = octx->dst; const char * op_type = "rope-f32"; @@ -424,19 +424,16 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - // Assign sizes octx->src0_spad.size_per_thread = src0_spad_per_thread; octx->dst_spad.size_per_thread = dst_spad_per_thread; octx->src0_spad.size = n_threads * src0_spad_per_thread; octx->dst_spad.size = n_threads * dst_spad_per_thread; octx->src1_spad.size = 0; - // Assign pointers - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = NULL; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = NULL; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; - // Fill context struct htp_rope_context rctx; memset(&rctx, 0, sizeof(struct htp_rope_context)); @@ -483,7 +480,7 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { int op_rope(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_rope_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 4b6967749f8..0def7b408bf 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -14,33 +14,37 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" -#define set_rows_preamble \ - const uint32_t ne00 = octx->src0.ne[0]; \ - const uint32_t ne01 = octx->src0.ne[1]; \ - const uint32_t ne02 = octx->src0.ne[2]; \ - const uint32_t ne03 = octx->src0.ne[3]; \ - \ - const uint32_t ne10 = octx->src1.ne[0]; \ - const uint32_t ne11 = octx->src1.ne[1]; \ - const uint32_t ne12 = octx->src1.ne[2]; \ - \ - const uint32_t nb01 = octx->src0.nb[1]; \ - const uint32_t nb02 = octx->src0.nb[2]; \ - const uint32_t nb03 = octx->src0.nb[3]; \ - \ - const uint32_t nb10 = octx->src1.nb[0]; \ - const uint32_t nb11 = octx->src1.nb[1]; \ - const uint32_t nb12 = octx->src1.nb[2]; \ - \ - const uint32_t nb1 = octx->dst.nb[1]; \ - const uint32_t nb2 = octx->dst.nb[2]; \ - const uint32_t nb3 = octx->dst.nb[3]; \ - \ - const uint32_t ne1 = octx->dst.ne[1]; \ - \ +#define set_rows_preamble \ + const uint32_t ne00 = octx->src[0]->ne[0]; \ + const uint32_t ne01 = octx->src[0]->ne[1]; \ + const uint32_t ne02 = octx->src[0]->ne[2]; \ + const uint32_t ne03 = octx->src[0]->ne[3]; \ + \ + const uint32_t ne10 = octx->src[1]->ne[0]; \ + const uint32_t ne11 = octx->src[1]->ne[1]; \ + const uint32_t ne12 = octx->src[1]->ne[2]; \ + const uint32_t ne13 = octx->src[1]->ne[3]; \ + \ + const uint32_t nb01 = octx->src[0]->nb[1]; \ + const uint32_t nb02 = octx->src[0]->nb[2]; \ + const uint32_t nb03 = octx->src[0]->nb[3]; \ + \ + const uint32_t nb10 = octx->src[1]->nb[0]; \ + const uint32_t nb11 = octx->src[1]->nb[1]; \ + const uint32_t nb12 = octx->src[1]->nb[2]; \ + \ + const uint32_t nb1 = octx->dst->nb[1]; \ + const uint32_t nb2 = octx->dst->nb[2]; \ + const uint32_t nb3 = octx->dst->nb[3]; \ + \ + const uint32_t ne0 = octx->dst->ne[0]; \ + const uint32_t ne1 = octx->dst->ne[1]; \ + const uint32_t ne2 = octx->dst->ne[2]; \ + const uint32_t ne3 = octx->dst->ne[3]; \ + \ const uint32_t nr = ne01; struct htp_set_rows_context { @@ -56,12 +60,14 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da set_rows_preamble; + uint64_t qt = HAP_perf_get_qtimer_count(); + // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; - const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { @@ -70,7 +76,7 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; - const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i1 >= ne1) { @@ -78,14 +84,18 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da continue; } - const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; - const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + const uintptr_t src0_ptr = octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03; + const uintptr_t dst_ptr = octx->dst->data + i1*nb1 + i02*nb2 + i03*nb3; // copy row hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); } } } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "set-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) { @@ -94,12 +104,14 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da set_rows_preamble; + uint64_t qt = HAP_perf_get_qtimer_count(); + // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; - const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { @@ -108,7 +120,7 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; - const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i1 >= ne1) { @@ -116,13 +128,17 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da continue; } - const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; - uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + const uint8_t* src0_ptr = (const uint8_t *) octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03; + uint8_t* dst_ptr = (uint8_t *) octx->dst->data + i1*nb1 + i02*nb2 + i03*nb3; hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00); } } } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "set-rows-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } int op_set_rows(struct htp_ops_context * octx) { @@ -130,15 +146,15 @@ int op_set_rows(struct htp_ops_context * octx) { const uint32_t n_threads = MIN(nr, octx->n_threads); - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) { + if (octx->dst->type != HTP_TYPE_F32 && octx->dst->type != HTP_TYPE_F16) { return HTP_STATUS_NO_SUPPORT; } - if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) { return HTP_STATUS_NO_SUPPORT; } @@ -153,7 +169,7 @@ int op_set_rows(struct htp_ops_context * octx) { srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; - switch(octx->dst.type) { + switch(octx->dst->type) { case HTP_TYPE_F32: worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads); break; diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index d6356b9506f..d78bcc0eb24 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -15,68 +15,89 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" -#define htp_softmax_preamble3 \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \ - const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \ - const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \ - const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \ - \ - const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \ - const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \ - const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \ - const uint32_t nb13 = (src1->ne[0]) ? src1->nb[3] : 1; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_softmax_preamble3 \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne10 = src1 ? src1->ne[0] : 1; \ + const uint32_t ne11 = src1 ? src1->ne[1] : 1; \ + const uint32_t ne12 = src1 ? src1->ne[2] : 1; \ + const uint32_t ne13 = src1 ? src1->ne[3] : 1; \ + \ + const uint32_t nb10 = src1 ? src1->nb[0] : 1; \ + const uint32_t nb11 = src1 ? src1->nb[1] : 1; \ + const uint32_t nb12 = src1 ? src1->nb[2] : 1; \ + const uint32_t nb13 = src1 ? src1->nb[3] : 1; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; struct htp_softmax_context { + struct htp_ops_context * octx; + bool use_f16; bool use_src1; + uint32_t n_head; uint32_t n_head_log2; - float scale; - float max_bias; - float m0; - float m1; + float scale; + float max_bias; + float m0; + float m1; - uint32_t src0_nrows_per_thread; struct fastdiv_values fastdiv_ne01; struct fastdiv_values fastdiv_ne02; struct fastdiv_values fastdiv_ne12; // For mask broadcasting struct fastdiv_values fastdiv_ne13; // For mask broadcasting - size_t spad_stride; - struct htp_ops_context * octx; + uint32_t src0_nrows_per_thread; }; +static void apply_mask(float * restrict wp0, + const float * restrict mp_f32, + const __fp16 * restrict mp_f16, + uint32_t ne00, + float slope, + bool use_f16) { + if (!mp_f32) { + return; + } + if (use_f16) { + for (uint32_t i = 0; i < ne00; ++i) { + wp0[i] += slope * (float) mp_f16[i]; + } + } else { + for (uint32_t i = 0; i < ne00; ++i) { + wp0[i] += slope * mp_f32[i]; + } + } +} + static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; memset(smctx, 0, sizeof(struct htp_softmax_context)); - memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float)); + memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float)); memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); smctx->n_head = src0->ne[2]; @@ -85,8 +106,8 @@ static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_ smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2); smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2); - smctx->use_src1 = (src1->ne[0] != 0); - smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16); + smctx->use_src1 = (src1 != 0); + smctx->use_f16 = (src1 != 0) && (src1->type == HTP_TYPE_F16); smctx->octx = octx; @@ -97,8 +118,8 @@ static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_ if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01); if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02); - const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; - const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; + const uint32_t ne12 = src1 ? src1->ne[2] : 1; + const uint32_t ne13 = src1 ? src1->ne[3] : 1; if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12); if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13); @@ -139,10 +160,7 @@ static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src, } } -static void hvx_fast_softmax_f32(const uint8_t * restrict src, - uint8_t * restrict dst, - uint8_t * restrict pad, - const int num_elems) { +static void hvx_fast_softmax_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict pad, const int num_elems) { const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_pad = (HVX_Vector *) pad; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; @@ -188,27 +206,20 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, } } -static float hvx_softmax_f32(const uint8_t * restrict src, - uint8_t * restrict dst, - uint8_t * restrict spad, - const int num_elems, - const float max) { +static float hvx_softmax_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict spad, const int num_elems, const float max) { hvx_sub_scalar_f32(spad, src, max, num_elems); hvx_exp_f32(dst, spad, num_elems, false); - - float sum = hvx_reduce_sum_f32(dst, num_elems); - - return sum; + return hvx_reduce_sum_f32(dst, num_elems); } static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { struct htp_softmax_context * smctx = (struct htp_softmax_context *) data; struct htp_ops_context * octx = smctx->octx; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; htp_softmax_preamble3; @@ -223,22 +234,26 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { return; } - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + uint64_t qt = HAP_perf_get_qtimer_count(); int is_aligned = 1; int opt_path = 0; + if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) { is_aligned = 0; FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n"); } + + // Only use the fast path when aligned AND row size is multiple of VLEN (128 bytes) + // The fast path (hvx_fast_softmax_f32) doesn't handle tail elements + // The non-opt path uses hvx_softmax_f32 which properly handles all sizes via its helper functions if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { opt_path = 1; } - uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride); - uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride); - uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride); + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); float * wp0 = (float *) src0_spad_data; float * wp1 = (float *) src1_spad_data; @@ -278,47 +293,29 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { // ALiBi if (i2 != prev_i2) { const uint32_t h = i2; // head - - slope = (smctx->max_bias > 0.0f) ? - h < smctx->n_head_log2 ? - powf(smctx->m0, h + 1) : - powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) : - 1.0f; + slope = (smctx->max_bias > 0.0f) ? h < smctx->n_head_log2 ? powf(smctx->m0, h + 1) : powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) : 1.0f; prev_i2 = i2; } - float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03); - float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3); + float * sp = (float *) ((char *) src0->data + i1 * nb01 + i2 * nb02 + i3 * nb03); + float * dp = (float *) ((char *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3); // broadcast the mask across rows - __fp16 * mp_f16 = (smctx->use_src1) ? - (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; - float * mp_f32 = (smctx->use_src1) ? - (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; + __fp16 * mp_f16 = (smctx->use_src1) ? (__fp16 *) ((char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13) : NULL; + float * mp_f32 = (smctx->use_src1) ? (float *) ((char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13) : NULL; if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) { - hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale, - (const uint8_t *) mp_f32, slope); - } else { + hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale, (const uint8_t *) mp_f32, slope); + hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); + } else if (1 == opt_path) { hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale); - if (mp_f32) { - if (smctx->use_f16) { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * (float) mp_f16[i]; - } - } else { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * mp_f32[i]; - } - } - } - } - - if (1 == opt_path) { + apply_mask(wp0, mp_f32, mp_f16, ne00, slope, smctx->use_f16); hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); } else { + // Non-optimized path: uses HVX helper functions that properly handle all tensor sizes + // including non-multiples of 32 (the HVX vector lane count for f32) + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale); + apply_mask(wp0, mp_f32, mp_f16, ne00, slope, smctx->use_f16); float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00); float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); sum = sum > 0.0 ? (1.0 / sum) : 1; @@ -326,54 +323,47 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { } } - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, - ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "softmax-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u : opt %u f16 %u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, + ne0, ne1, ne2, ne3, opt_path, smctx->use_f16, (unsigned) qt); } static int execute_op_softmax_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; struct htp_softmax_context smctx; const char * op_type = "softmax-f32"; - switch (octx->op) { - case HTP_OP_SOFTMAX: - init_softmax_ctx(&smctx, octx); - break; - - default: - FARF(ERROR, "Unsupported Op %u\n", octx->op); - return HTP_STATUS_NO_SUPPORT; - } + init_softmax_ctx(&smctx, octx); const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); + smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + const size_t src0_row_size = src0->nb[1]; const size_t src1_row_size = src0_row_size; const size_t dst_row_size = dst->nb[1]; // VTCM scratchpads for all tensors - // N rows per thread, padded to HVX vector size - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + // 4 rows per thread, padded to HVX vector size + octx->src0_spad.size_per_thread = hex_round_up(4 * src0_row_size, 128); + octx->src1_spad.size_per_thread = hex_round_up(4 * src1_row_size, 128); + octx->dst_spad.size_per_thread = hex_round_up(4 * dst_row_size, 128); - // Use stride for calculating offset - smctx.spad_stride = hex_round_up(src0_row_size, 128); + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads; size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; - if (src1->ne[0]) { - FARF(HIGH, - "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", + if (src1) { + FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); @@ -385,19 +375,17 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); + FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; - worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads); - } + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) return err; + + worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads); return err; } @@ -405,7 +393,7 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { int op_softmax(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_softmax_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index 6b035810d57..a28fd03e978 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -16,14 +16,14 @@ #include "ggml-common.h" #include "htp-ctx.h" #include "hex-dma.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #include "hvx-utils.h" -#define htp_ssm_conv_tensors_preamble \ - struct htp_tensor * restrict src0 = &octx->src0; \ - struct htp_tensor * restrict src1 = &octx->src1; \ - struct htp_tensor * restrict dst = &octx->dst; \ +#define htp_ssm_conv_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict dst = octx->dst; \ struct htp_spad * restrict src0_spad = &octx->src0_spad; \ struct htp_spad * restrict src1_spad = &octx->src1_spad; \ struct htp_spad * restrict dst_spad = &octx->dst_spad; \ @@ -289,9 +289,9 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { // Compute gather scratchpad size for src0 and src1 const size_t gather_spad_size = n_threads * VLEN * 2; - octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n", gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, @@ -323,8 +323,9 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { } int op_ssm_conv(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; switch (dst->type) { case HTP_TYPE_F32: diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c index 352650b689b..874c41ab2ac 100644 --- a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -14,13 +14,13 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" -#define sum_rows_preamble \ - struct htp_tensor *src0 = &octx->src0;\ - struct htp_tensor *dst = &octx->dst; \ - \ +#define sum_rows_preamble \ + const struct htp_tensor *src0 = octx->src[0]; \ + const struct htp_tensor *dst = octx->dst; \ + \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ const uint32_t ne02 = src0->ne[2]; \ @@ -94,7 +94,7 @@ static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) int op_sum_rows(struct htp_ops_context * octx) { sum_rows_preamble; - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 13d28317d5c..03eccfd55e3 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -16,7 +16,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" struct htp_unary_context { @@ -267,8 +267,8 @@ static void softplus_f32(const float * restrict src, static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; struct htp_ops_context * octx = uctx->octx; - const struct htp_tensor * src = &octx->src0; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src = octx->src[0]; + const struct htp_tensor * dst = octx->dst; htp_unary_preamble; @@ -387,8 +387,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * static int execute_op_unary_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; const char * op_type = NULL; @@ -490,7 +490,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { int op_unary(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_unary_f32(octx); break; From 3af7c879bc3317337fb46f5b00ca1702243b8a56 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 11 Apr 2026 10:30:30 +0800 Subject: [PATCH 117/249] CUDA: also store node->src ne/nb for graph equality (llama/21736) --- ggml/src/ggml-cuda/common.cuh | 4 +++- ggml/src/ggml-cuda/ggml-cuda.cu | 12 +++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 56a67f1edc8..8a4246223b5 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1185,7 +1185,9 @@ struct ggml_cuda_graph { bool warmup_complete = false; struct node_properties { ggml_tensor node; - void * node_src_data_ptrs[GGML_MAX_SRC]; + void * node_src_data_ptrs[GGML_MAX_SRC]; + int64_t node_src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; + size_t node_src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; }; std::vector node_props; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8613d20b9f9..3113de017f0 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3070,16 +3070,18 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx ggml_cuda_graph::node_properties prop = {}; memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor)); - // if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see: - // https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188 for (int j = 0; j < GGML_MAX_SRC; ++j) { - prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr; + if (cgraph->nodes[i]->src[j]) { + prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j]->data; + memcpy(prop.node_src_ne[j], cgraph->nodes[i]->src[j]->ne, sizeof(prop.node_src_ne[j])); + memcpy(prop.node_src_nb[j], cgraph->nodes[i]->src[j]->nb, sizeof(prop.node_src_nb[j])); + } } - if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) { + if (res || memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) { + graph->node_props[i] = prop; res = true; } - graph->node_props[i] = prop; } return res; From 34381b01c44c2fc0c40a00e2086fd6318bd7f570 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sat, 11 Apr 2026 08:45:00 +0200 Subject: [PATCH 118/249] ggml : fix a few instances of missing GGML_TYPE_Q1_0 cases (llama/21716) --- ggml/src/ggml-cpu/ops.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 0b5d6c6df88..a9bc21da6f0 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -664,6 +664,7 @@ void ggml_compute_forward_add( { ggml_compute_forward_add_non_quantized(params, dst); } break; + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1113,6 +1114,7 @@ void ggml_compute_forward_add1( GGML_ABORT("fatal error"); } } break; + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1242,6 +1244,7 @@ void ggml_compute_forward_acc( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4331,6 +4334,7 @@ void ggml_compute_forward_out_prod( const ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4606,6 +4610,7 @@ void ggml_compute_forward_set( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: From e0c8e505e995a3936998ca36feacaf0cf1950133 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Sat, 11 Apr 2026 01:46:19 -0700 Subject: [PATCH 119/249] opencl: add basic support for q5_k (llama/21593) * opencl: add general q5_k mv * opencl: add flattened Q5_K mv and general Q5_K mm * opencl: fix Q5_K unit tests --- ggml/src/ggml-opencl/CMakeLists.txt | 3 + ggml/src/ggml-opencl/ggml-opencl.cpp | 384 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 76 ++++ .../kernels/mul_mm_q5_k_f32_l4_lm.cl | 192 +++++++++ .../ggml-opencl/kernels/mul_mv_q5_k_f32.cl | 187 +++++++++ .../kernels/mul_mv_q5_k_f32_flat.cl | 203 +++++++++ 6 files changed, 1043 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 540942b195d..112c2afe821 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -90,6 +90,8 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_1_f32_flat mul_mv_q4_k_f32 mul_mv_q4_k_f32_flat + mul_mv_q5_k_f32 + mul_mv_q5_k_f32_flat mul_mv_q6_k_f32 mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 @@ -109,6 +111,7 @@ set(GGML_OPENCL_KERNELS mul_mm_q4_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm mul_mm_q4_k_f32_l4_lm + mul_mm_q5_k_f32_l4_lm mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 gemv_noshuffle_q4_1_f32 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index f1a28a7f4cd..a581402300a 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -541,12 +541,15 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_K_noshuffle; cl_kernel kernel_restore_block_q4_K_noshuffle; cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; + cl_kernel kernel_convert_block_q5_K, kernel_restore_block_q5_K; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; cl_kernel kernel_mul_mv_q4_1_f32_flat; cl_kernel kernel_mul_mv_q4_K_f32; cl_kernel kernel_mul_mv_q4_K_f32_flat; + cl_kernel kernel_mul_mv_q5_K_f32; + cl_kernel kernel_mul_mv_q5_K_f32_flat; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; @@ -587,6 +590,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; cl_kernel kernel_mul_mm_q4_k_f32_l4_lm; + cl_kernel kernel_mul_mm_q5_k_f32_l4_lm; cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; std::vector profiling_info; @@ -938,6 +942,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_K", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err)); @@ -1249,6 +1255,39 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q5_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_k_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_K_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_k_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_k_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_k_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_K_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + } + // mul_mv_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1556,6 +1595,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q5_k_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q5_k_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q5_k_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q5_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_f16_f32_kq_kqv { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3530,6 +3586,58 @@ struct ggml_tensor_extra_cl_q4_K { } }; +struct ggml_tensor_extra_cl_q5_K { + // Lower 4 bits of quantized weights. + cl_mem q = nullptr; + // Upper 1 bit of quantized weights. + cl_mem qh = nullptr; + // Scales for each block. + cl_mem s = nullptr; + // Scales for each super block. + cl_mem d = nullptr; + // Min for each super block. + cl_mem dm = nullptr; + + size_t size_q = 0; + size_t size_qh = 0; + size_t size_s = 0; + size_t size_d = 0; + size_t size_dm = 0; + + ~ggml_tensor_extra_cl_q5_K() { + reset(); + } + + void reset() { + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (s != nullptr) { + CL_CHECK(clReleaseMemObject(s)); + s = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (dm != nullptr) { + CL_CHECK(clReleaseMemObject(dm)); + dm = nullptr; + } + + size_q = 0; + size_qh = 0; + size_s = 0; + size_d = 0; + size_dm = 0; + } +}; + struct ggml_tensor_extra_cl_q6_K { // Lower 4 bits of quantized weights. cl_mem ql = nullptr; @@ -3945,6 +4053,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_MXFP4 || op->src[0]->type == GGML_TYPE_Q4_K || + op->src[0]->type == GGML_TYPE_Q5_K || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } else if (op->src[0]->type == GGML_TYPE_Q8_0) { @@ -4153,6 +4262,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { delete e; } + for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K) { + delete e; + } + for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K_in_use) { + delete e; + } } ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() { @@ -4245,6 +4360,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q5_K * ggml_opencl_alloc_temp_tensor_extra_q5_K() { + ggml_tensor_extra_cl_q5_K * extra; + if (temp_tensor_extras_q5_K.empty()) { + extra = new ggml_tensor_extra_cl_q5_K(); + } else { + extra = temp_tensor_extras_q5_K.back(); + temp_tensor_extras_q5_K.pop_back(); + } + + temp_tensor_extras_q5_K_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() { ggml_tensor_extra_cl_q6_K * extra; if (temp_tensor_extras_q6_K.empty()) { @@ -4291,6 +4421,11 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q4_K_in_use.clear(); + for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K_in_use) { + temp_tensor_extras_q5_K.push_back(e); + } + temp_tensor_extras_q5_K_in_use.clear(); + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { temp_tensor_extras_q6_K.push_back(e); } @@ -4314,6 +4449,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_q8_0_in_use; std::vector temp_tensor_extras_q4_K; std::vector temp_tensor_extras_q4_K_in_use; + std::vector temp_tensor_extras_q5_K; + std::vector temp_tensor_extras_q5_K_in_use; std::vector temp_tensor_extras_q6_K; std::vector temp_tensor_extras_q6_K_in_use; @@ -5152,6 +5289,97 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } + if (tensor->type == GGML_TYPE_Q5_K) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q5_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q5_K(); + + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/8; + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(3*ggml_blck_size(tensor->type)/64); + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + GGML_ASSERT(size_q + size_qh + size_s + size_d + size_dm == ggml_nbytes(tensor) && + "Incorrect tensor size"); + + cl_int err; + cl_mem data_device; + CL_CHECK((data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err), err)); + CL_CHECK(clEnqueueWriteBuffer(queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Create subbuffer for d. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for dm. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_dm; + extra->dm = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for s. + region.origin = align_to(previous_origin + size_dm, backend_ctx->alignment); + region.size = size_s; + extra->s = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for q (lower 4 bits) + region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for qh (upper 1 bit) + region.origin = align_to(previous_origin + size_q, backend_ctx->alignment); + region.size = size_qh; + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra->dm)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + extra->size_q = size_q; + extra->size_qh = size_qh; + extra->size_s = size_s; + extra->size_d = size_d; + extra->size_dm = size_dm; + + tensor->extra = extra; + return; + } if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -5658,6 +5886,35 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + if (tensor->type == GGML_TYPE_Q5_K) { + ggml_tensor_extra_cl_q5_K * extra = (ggml_tensor_extra_cl_q5_K *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; @@ -10221,6 +10478,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; #endif @@ -10925,6 +11183,51 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q5_K: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q5_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q6_K: { if (ne11 < 32) { break; @@ -11442,7 +11745,81 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #endif // GGML_OPENCL_SOA_Q break; } - case GGML_TYPE_Q5_K: + case GGML_TYPE_Q5_K: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_q5_K_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = 16; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &r3)); +#else + kernel = backend_ctx->kernel_mul_mv_q5_K_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } case GGML_TYPE_Q6_K: #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_q6_K_f32_flat; @@ -11610,7 +11987,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } else if (src0t == GGML_TYPE_Q3_K) { GGML_ASSERT(false && "not implemented"); } else if (src0t == GGML_TYPE_Q5_K) { - GGML_ASSERT(false && "not implemented"); + size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } else if (src0t == GGML_TYPE_Q6_K) { size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 81fe17fa10f..1bd83d29b3d 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -66,6 +66,17 @@ struct block_q4_K { uchar q[QK_K / 2]; // nibbles / quants }; +//------------------------------------------------------------------------------ +// block_q5_k +//------------------------------------------------------------------------------ +struct block_q5_K { + half d; // delta + half dm; // min + uchar s[K_SCALE_SIZE]; + uchar qh[QK_K / 8]; + uchar qs[QK_K / 2]; // nibbles / quants +}; + //------------------------------------------------------------------------------ // block_q6_K //------------------------------------------------------------------------------ @@ -546,6 +557,71 @@ kernel void kernel_restore_block_q4_K_noshuffle( } } +//------------------------------------------------------------------------------ +// kernel_convert_block_q5_K +// Convert the block_q5_K format to 5 separate arrays (AOS -> SOA). +// Each thread processes a super block. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q5_K( + global struct block_q5_K * src0, + global uchar * dst_q, + global uchar * dst_qh, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm +) { + global struct block_q5_K * b = (global struct block_q5_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/8*get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K/2; ++i) { + q[i] = b->qs[i]; + } + for (int i = 0; i < QK_K/8; ++i) { + qh[i] = b->qh[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +// Restore block_q5_K from flattened arrays. +// Each thread processes a super block. +kernel void kernel_restore_block_q5_K( + global uchar * src_q, + global uchar * src_qh, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q5_K * dst +) { + global struct block_q5_K * b = (global struct block_q5_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) src_qh + QK_K/8*get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K/2; ++i) { + b->qs[i] = q[i]; + } + for (int i = 0; i < QK_K/8; ++i) { + b->qh[i] = qh[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q6_K // Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl new file mode 100644 index 00000000000..8e191f57e83 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl @@ -0,0 +1,192 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 4 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q5_k_f32_l4_lm( + global uchar4 * src0_q, + global uchar * src0_qh, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 64; + int iqs = (idx % 64) * 2; + + int n = iqs / 32; + int b = (iqs % 32) / 16; + int is = 2 * n + b; + int qsi = n * 32 + (iqs % 16) * 2; + + global uchar * scales = src0_s + ib * 12; + + int scidx0 = (is < 4) ? is : (is + 4); + int scidx1 = (is < 4) ? is : (is - 4); + int scidxmask1 = (is < 4) ? 0x30 : 0xC0; + int scidxshift1 = (is < 4) ? 0 : 2; + int mbidx0 = is + 4; + int mbidx1 = (is < 4) ? is + 4 : is; + int mbidxmask0 = (is < 4) ? 0xF : 0xF0; + int mbidxshift0 = (is < 4) ? 0 : 4; + int mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + int mbidxshift1 = (is < 4) ? 0 : 2; + + uchar sc = (scales[scidx0] & 0xF) | ((scales[scidx1] & scidxmask1) >> scidxshift1); + uchar mbyte = ((scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((scales[mbidx1] & mbidxmask1) >> mbidxshift1); + + float d = (float)src0_d[ib] * (float)sc; + float m = -(float)src0_dm[ib] * (float)mbyte; + + int qh_base = (iqs % 16) * 2; + int bit_pos = 2*n + b; + uchar h0 = (src0_qh[ib*32 + qh_base + 0] >> bit_pos) & 1; + uchar h1 = (src0_qh[ib*32 + qh_base + 1] >> bit_pos) & 1; + uchar h2 = (src0_qh[ib*32 + qh_base + 2] >> bit_pos) & 1; + uchar h3 = (src0_qh[ib*32 + qh_base + 3] >> bit_pos) & 1; + + global uchar4 * qs = src0_q + ib*32 + (qsi >> 2); + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)( + ((q.s0 >> (b * 4))&0x0F) | (h0 << 4), + ((q.s1 >> (b * 4))&0x0F) | (h1 << 4), + ((q.s2 >> (b * 4))&0x0F) | (h2 << 4), + ((q.s3 >> (b * 4))&0x0F) | (h3 << 4) + )))*d + m; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v1.s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl new file mode 100644 index 00000000000..b2058abc1b6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl @@ -0,0 +1,187 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK_K 256 +#define K_SCALE_SIZE 12 + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qh[QK_K/8]; // quants, high bit (1 bit per value, packed 8 per byte) + uchar qs[QK_K/2]; // quants, low 4 bits (2 values per byte) +} block_q5_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 16 +#elif defined(ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_K_f32( + global char * src0, + int offset0, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; // super block index + int it = get_sub_group_local_id()%8; // block index (inside super block) + int iq = it/4; // 0 or 1 - first or second half of the super block + int ir = it%4; // 0...3 - block index in the half super block + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global block_q5_K * x = (global block_q5_K *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uchar u1_lo = (uchar)(1 << (2*iq)); + uchar u2_lo = (uchar)(2 << (2*iq)); + uchar u1_hi = (uchar)(1 << (2*iq + 4)); + uchar u2_hi = (uchar)(2 << (2*iq + 4)); + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * sc = (global ushort *)x[ib].scales + iq; + global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir; + global uchar * qh = x[ib].qh + 8 * ir; + global half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * ((q1[i/2] & 0x000F) + (qh[i+0] & u1_lo ? 16.f : 0.f)); + acc1.s1 += yl[i+1] * ((q1[i/2] & 0x0F00) + (qh[i+1] & u1_lo ? 16.f*256.f : 0.f)); + acc1.s2 += yl[i+8] * ((q1[i/2] & 0x00F0) + (qh[i+0] & u2_lo ? 16.f*16.f : 0.f)); + acc1.s3 += yl[i+9] * ((q1[i/2] & 0xF000) + (qh[i+1] & u2_lo ? 16.f*4096.f: 0.f)); + acc2.s0 += yh[i+0] * ((q2[i/2] & 0x000F) + (qh[i+0] & u1_hi ? 16.f : 0.f)); + acc2.s1 += yh[i+1] * ((q2[i/2] & 0x0F00) + (qh[i+1] & u1_hi ? 16.f*256.f : 0.f)); + acc2.s2 += yh[i+8] * ((q2[i/2] & 0x00F0) + (qh[i+0] & u2_hi ? 16.f*16.f : 0.f)); + acc2.s3 += yh[i+9] * ((q2[i/2] & 0xF000) + (qh[i+1] & u2_hi ? 16.f*4096.f: 0.f)); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; + qh += nb01; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl new file mode 100644 index 00000000000..e353a72be70 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl @@ -0,0 +1,203 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q5_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define BLOCK_Q5K_SIZE 176 +#define K_SCALE_SIZE 12 + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qh[QK_K/8]; // quants, high bit (1 bit per value, packed 8 per byte) + uchar qs[QK_K/2]; // quants, low 4 bits (2 values per byte) +} block_q5_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 16 +#elif defined(ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_K_f32_flat( + global uchar * src0_q, + global uchar * src0_qh, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; + int it = get_sub_group_local_id()%8; + int iq = it/4; + int ir = it%4; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = (first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03)/BLOCK_Q5K_SIZE; + uint blk = nb01 / BLOCK_Q5K_SIZE; + global uchar * blk_q = (global uchar *)src0_q + offset_src0*(QK_K/2); + global uchar * blk_qh = (global uchar *)src0_qh + offset_src0*(QK_K/8); + global uchar * blk_s = (global uchar *)src0_s + offset_src0*K_SCALE_SIZE; + global half * blk_d = (global half *)src0_d + offset_src0; + global half * blk_dm = (global half *)src0_dm + offset_src0; + + int offset_src1 = r1*nb11 + (i12)*nb12 + (i13)*nb13; + global float * y = (global float *)(src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uchar u1_lo = (uchar)(1 << (2*iq)); + uchar u2_lo = (uchar)(2 << (2*iq)); + uchar u1_hi = (uchar)(1 << (2*iq + 4)); + uchar u2_hi = (uchar)(2 << (2*iq + 4)); + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * q1 = (global ushort *)(blk_q + ib * (QK_K/2)) + (16 * iq + 4 * ir); + global uchar * qh = (global uchar *)(blk_qh + ib * (QK_K/8)) + 8 * ir; + global ushort * sc = (global ushort *)(blk_s + ib * K_SCALE_SIZE) + iq; + global half * d = blk_d + ib; + global half * dm = blk_dm + ib; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * ((q1[i/2] & 0x000F) + (qh[i+0] & u1_lo ? 16.f : 0.f)); + acc1.s1 += yl[i+1] * ((q1[i/2] & 0x0F00) + (qh[i+1] & u1_lo ? 16.f*256.f : 0.f)); + acc1.s2 += yl[i+8] * ((q1[i/2] & 0x00F0) + (qh[i+0] & u2_lo ? 16.f*16.f : 0.f)); + acc1.s3 += yl[i+9] * ((q1[i/2] & 0xF000) + (qh[i+1] & u2_lo ? 16.f*4096.f: 0.f)); + acc2.s0 += yh[i+0] * ((q2[i/2] & 0x000F) + (qh[i+0] & u1_hi ? 16.f : 0.f)); + acc2.s1 += yh[i+1] * ((q2[i/2] & 0x0F00) + (qh[i+1] & u1_hi ? 16.f*256.f : 0.f)); + acc2.s2 += yh[i+8] * ((q2[i/2] & 0x00F0) + (qh[i+0] & u2_hi ? 16.f*16.f : 0.f)); + acc2.s3 += yh[i+9] * ((q2[i/2] & 0xF000) + (qh[i+1] & u2_hi ? 16.f*4096.f: 0.f)); + } + + float dall = *d; + float dmin = *dm; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += blk*64; + qh += blk*32; + sc += blk*6; + d += blk; + dm += blk; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} From c0b46c2f8f3eac135f6f8d32dc3863e0c8898e8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 11 Apr 2026 18:52:11 +0200 Subject: [PATCH 120/249] CUDA: skip compilation of superfluous FA kernels (llama/21768) --- ggml/src/ggml-cuda/fattn.cu | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index addf93205ef..ea6607cd337 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -75,13 +75,17 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con return; } - if (use_gqa_opt && gqa_ratio % 2 == 0) { - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + if constexpr (DKQ <= 256) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; + } else { + GGML_ABORT("fatal error"); } - - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); - return; } if (use_gqa_opt && gqa_ratio > 4) { @@ -94,12 +98,16 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con return; } - if (use_gqa_opt && gqa_ratio > 1) { - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); - return; - } + if constexpr (DKQ <= 256) { + if (use_gqa_opt && gqa_ratio > 1) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + } else { + GGML_ABORT("fatal error"); + } } static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { From b9072073128ebd7bdb98c6d328dce6e38f983109 Mon Sep 17 00:00:00 2001 From: Stephen Cox Date: Mon, 13 Apr 2026 00:15:26 +1200 Subject: [PATCH 121/249] mtmd: add Gemma 4 audio conformer encoder support (llama/21421) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mtmd: add Gemma 4 audio conformer encoder support Add audio processing for Gemma 4 E2B/E4B via a USM-style Conformer. Architecture: - 12-layer Conformer: FFN → Self-Attention → Causal Conv1D → FFN → Norm - Subsampling Conv Projection: 2x Conv2D(stride=2) with LayerNorm - Full self-attention with sinusoidal RPE and sliding window mask (24) - Logit softcapping at 50.0, ClippableLinear clamping - Output: 1024 → 1536 → RMSNorm → multimodal embedder Mel preprocessing (dedicated mtmd_audio_preprocessor_gemma4a): - HTK mel scale, 128 bins, magnitude STFT, mel_floor=1e-3 - Standard periodic Hann window (320 samples), zero-padded to FFT size - Semicausal left-padding (frame_length/2 samples) - Frame count matched to PyTorch (unfold formula) - No pre-emphasis, no Whisper-style normalization - Mel cosine similarity vs PyTorch: 0.9998 Key fixes: - Tensor loading dedup: prevent get_tensor() from creating duplicate entries in ctx_data. Fixed with std::set guard. - ClippableLinear clamp_info loading moved after per-layer tensors. - Sliding window mask (24 positions) matching PyTorch context_size. - Skip Whisper normalization for Gemma4 mel output. Tested on E2B and E4B with CPU and Vulkan backends. Transcribes: "Glad to see things are going well and business is starting to pick up" (matching ground truth). Ref: #21325 --- ggml/src/ggml-cuda/ssm-conv.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 69985cd335c..b77cdc1c137 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -134,8 +134,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int switch (nc) { case 3: launch_kernel(std::integral_constant{}); break; case 4: launch_kernel(std::integral_constant{}); break; + case 5: launch_kernel(std::integral_constant{}); break; case 9: launch_kernel(std::integral_constant{}); break; - default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now."); + default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now."); } } From 655072cd78a989c0696efaa1e54e616b6b8c2678 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Mon, 13 Apr 2026 07:14:58 +0530 Subject: [PATCH 122/249] sycl: disable Q1_0 in backend and cleanup unused variables (llama/21807) --- ggml/src/ggml-sycl/convert.cpp | 2 +- ggml/src/ggml-sycl/dequantize.hpp | 1 + ggml/src/ggml-sycl/element_wise.cpp | 2 +- ggml/src/ggml-sycl/gated_delta_net.cpp | 10 ++++------ ggml/src/ggml-sycl/ggml-sycl.cpp | 7 +++++++ ggml/src/ggml-sycl/upscale.cpp | 8 ++++---- 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index d7f60cbc9ea..f12419426ae 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -488,7 +488,7 @@ static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t const int nb = k / QK_NVFP4; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { dequantize_block_nvfp4(vx, y, k); }); } diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index f992db33b2d..68c3db30613 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -14,6 +14,7 @@ #define GGML_SYCL_DEQUANTIZE_HPP #include "common.hpp" +#include "convert.hpp" typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs, diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index ec0247528c4..249e80c826e 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -355,7 +355,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst, const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); }); } diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index 648455c134b..ebc587524bf 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -176,14 +176,12 @@ static void launch_gated_delta_net(const float * q_d, const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1); const sycl::uint3 rq3_magic = init_fastdiv_values(rq3); - int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc; - switch (S_v) { case 16: { constexpr int sv = 16; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); @@ -194,7 +192,7 @@ static void launch_gated_delta_net(const float * q_d, { constexpr int sv = 32; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); @@ -205,7 +203,7 @@ static void launch_gated_delta_net(const float * q_d, { constexpr int sv = 64; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); @@ -217,7 +215,7 @@ static void launch_gated_delta_net(const float * q_d, { constexpr int sv = 128; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 989c91a6abb..ea79d2538c1 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4727,12 +4727,19 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g struct ggml_tensor * a = op->src[0]; struct ggml_tensor * b = op->src[1]; + // disable Q1_0 until implementation + if (a->type == GGML_TYPE_Q1_0 || b->type == GGML_TYPE_Q1_0) { + return false; + } + if (a->ne[3] != b->ne[3]) { return false; } ggml_type src0_type = op->src[0]->type; + + // TODO: The configuration below needs more work to be supported with oneDNN if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) { diff --git a/ggml/src/ggml-sycl/upscale.cpp b/ggml/src/ggml-sycl/upscale.cpp index 18c743de447..e42cb419d83 100644 --- a/ggml/src/ggml-sycl/upscale.cpp +++ b/ggml/src/ggml-sycl/upscale.cpp @@ -272,7 +272,7 @@ static void upscale_f32_sycl(const float * x, sycl::nd_range<3>( sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3); }); } @@ -304,7 +304,7 @@ static void upscale_f32_bilinear_sycl(const float * x, sycl::nd_range<3>( sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { upscale_f32_bilinear_antialias( x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); @@ -314,7 +314,7 @@ static void upscale_f32_bilinear_sycl(const float * x, sycl::nd_range<3>( sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { upscale_f32_bilinear( x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); @@ -349,7 +349,7 @@ static void upscale_f32_bicubic_sycl(const float * x, sycl::nd_range<3>( sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { upscale_f32_bicubic( x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); From 36b7bb3d9576e670037228bf58c768cdeb5ed450 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Mon, 13 Apr 2026 12:13:04 +0900 Subject: [PATCH 123/249] Remove extra conditional check on debug mode. (llama/21798) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e979783f020..634201bc64d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -534,11 +534,7 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); ctx->queue.Submit(1, &commands); - if (!ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, - ctx->debug_host_buf.GetSize())) { - GGML_LOG_ERROR("ggml_webgpu: Debug buffer map failed\n"); - return; - } + ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize()); const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange(); std::cout << "debug[0]: " << debug_data[0] << "\n"; ctx->debug_host_buf.Unmap(); From d9ed371c2c50bdaef8134a05694b826e1cb7f7c6 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Mon, 13 Apr 2026 11:14:06 +0200 Subject: [PATCH 124/249] CUDA: Limit DeviceSegmentedSort to immediate mode (llama/21718) * CUDA: Limit DeviceSegmentedSort to immediate mode DeviceSegmentedSort is currently not capturable in a cuda graph. Hence, we have to go for the slower DeviceSegmentedRadixSort in that case. Perf numbers on RTX Pro 6000 Blackwell Max-Q: DeviceSegmentedRadixSort in graph mode (i.e. CUDA Graphs) ARGSORT(type=f32,ne=[2048,512,1,1],order=1): 12291 runs - 105.94 us/run - 8192 kB/run - 73.75 GB/s ARGSORT(type=f32,ne=[4096,512,1,1],order=1): 10245 runs - 115.08 us/run - 16384 kB/run - 135.77 GB/s ARGSORT(type=f32,ne=[8192,512,1,1],order=1): 5125 runs - 221.22 us/run - 32768 kB/run - 141.26 GB/s ARGSORT(type=f32,ne=[16384,512,1,1],order=1): 2565 runs - 430.98 us/run - 65536 kB/run - 145.02 GB/s ARGSORT(type=f32,ne=[32768,512,1,1],order=1): 1028 runs - 1185.83 us/run - 131072 kB/run - 105.41 GB/s ARGSORT(type=f32,ne=[65536,512,1,1],order=1): 387 runs - 2748.62 us/run - 262144 kB/run - 90.95 GB/s DeviceSegmentedSort in immediate mode ARGSORT(type=f32,ne=[2048,512,1,1],order=1): 16388 runs - 71.17 us/run - 8192 kB/run - 109.78 GB/s ARGSORT(type=f32,ne=[4096,512,1,1],order=1): 12294 runs - 81.38 us/run - 16384 kB/run - 192.00 GB/s ARGSORT(type=f32,ne=[8192,512,1,1],order=1): 5125 runs - 240.81 us/run - 32768 kB/run - 129.77 GB/s ARGSORT(type=f32,ne=[16384,512,1,1],order=1): 2565 runs - 406.60 us/run - 65536 kB/run - 153.71 GB/s ARGSORT(type=f32,ne=[32768,512,1,1],order=1): 1285 runs - 873.23 us/run - 131072 kB/run - 143.15 GB/s ARGSORT(type=f32,ne=[65536,512,1,1],order=1): 516 runs - 2288.46 us/run - 262144 kB/run - 109.24 GB/s * Add test case for dispatch to DeviceSegmentedRadixSort We currently lack a way to force graph mode in CUDA, patch callback to invoke ggml_backend_compare_graph_backend twice to enforce each test to run in graph mode --- ggml/src/ggml-cuda/argsort.cu | 79 +++++++++++++++++++++++++---------- 1 file changed, 56 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index ed4e5de70f5..0f3f017b534 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -58,26 +58,48 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, size_t temp_storage_bytes = 0; + bool is_capturing = false; +#ifdef USE_CUDA_GRAPH + // Currently (confirmed for CCCL <= 3.2) DeviceSegmentedSort does not support stream capture, while DeviceSegmentedRadixSort does. + // See https://github.com/NVIDIA/cccl/issues/5661#issuecomment-3229037149 + // TODO: constrain this to the CCCL versions that have this issue once it's resolved in a future CCCL release. + cudaStreamCaptureStatus capture_status; + CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status)); + is_capturing = (capture_status != cudaStreamCaptureStatusNone); +#endif // USE_CUDA_GRAPH + if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream)); + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs( + nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols * nrows, nrows, // num items, num segments - offset_iterator, offset_iterator + 1, stream)); + CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + offset_iterator, offset_iterator + 1, stream)); } } else { if (nrows == 1) { - CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream)); + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, - dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1, - stream)); + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, stream)); } } @@ -86,22 +108,33 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { - CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream)); + CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream)); + CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, stream)); } } else { if (nrows == 1) { - CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream)); + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, - temp_indices, dst, ncols * nrows, nrows, offset_iterator, - offset_iterator + 1, stream)); + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, stream)); } } } From 0f99a47177a887b3771c876fcd5b8de4711c9fe3 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Mon, 13 Apr 2026 14:21:31 +0200 Subject: [PATCH 125/249] vulkan: Flash Attention DP4A shader for quantized KV cache (llama/20797) * use integer dot product for quantized KV flash attention * small improvements * fix SHMEM_STAGING indexing * add missing KV type quants * fixes * add supported quants to FA tests * readd fast paths for <8bit quants * fix mmq gate and shmem checks --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 97 +++++++-- .../vulkan-shaders/flash_attn.comp | 184 +++++++++++++++++- .../vulkan-shaders/flash_attn_base.glsl | 5 + .../vulkan-shaders/flash_attn_mmq_funcs.glsl | 149 ++++++++++++++ .../vulkan-shaders/mul_mmq_shmem_types.glsl | 6 + .../src/ggml-vulkan/vulkan-shaders/types.glsl | 1 + .../vulkan-shaders/vulkan-shaders-gen.cpp | 15 +- 7 files changed, 430 insertions(+), 27 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 977aff62d81..1bee3e187cf 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2858,11 +2858,10 @@ struct vk_fa_tuning_params { } }; -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type); static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { - GGML_UNUSED(kv_type); vk_fa_tuning_params result{}; result.path = FA_SCALAR; @@ -2914,7 +2913,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; - if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) { + if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) { result.block_rows /= 2; } @@ -3445,21 +3444,47 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->fp16) { CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, ) + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product && device->subgroup_clustered) { + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8) + } else +#endif + { + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, ) + } } else { CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32) + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product && device->subgroup_clustered) { + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8) + } else +#endif + { + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32) + } } #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { @@ -8780,7 +8805,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { GGML_UNUSED(f32acc); // Needs to be kept up to date on shader changes const uint32_t wg_size = params.workgroup_size; @@ -8789,21 +8814,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + const bool mmq = device->integer_dot_product && device->subgroup_clustered && + (kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 || + kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 || + kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL); + // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); const uint32_t tmpshv4 = wg_size * 4 * float_type_size; const uint32_t masksh = Bc * (Br + 1) * float_type_size; - const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; + uint32_t Qf, kvsh, kblocksh_size; + if (mmq) { + // block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds + const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size; + Qf = Br * (hsk / 32) * block_b_size; + + // kvsh uses D = HSV (K goes through kblocksh instead) + kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size; + + // block_a_cache size depends on quant type + uint32_t block_a_size; + switch (kv_type) { + case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break; + case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break; + case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break; + case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break; + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break; + default: block_a_size = 0; break; + } + kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size; + } else { + Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; + + const uint32_t D = std::max(hsk, hsv); + kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; - const uint32_t D = std::max(hsk, hsv); - const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; + kblocksh_size = 0; + } - const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh; + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported); return supported; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 11b7dce8578..6e6bdabc92e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -10,6 +10,13 @@ #extension GL_EXT_shader_subgroup_extended_types_float16 : require #endif +#ifdef MMQ +#extension GL_EXT_integer_dot_product : require +#extension GL_KHR_shader_subgroup_clustered : require + +#include "mul_mmq_shmem_types.glsl" +#endif + #extension GL_KHR_shader_subgroup_shuffle : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -41,15 +48,34 @@ shared FLOAT_TYPEV4 tmpshv4[tmpsh_size]; const uint32_t masksh_stride = Br + 1; shared FLOAT_TYPE masksh[Bc * masksh_stride]; +#ifndef MMQ const uint32_t qf_stride = HSK / 4 + 1; shared FLOAT_TYPEV4 Qf[Br * qf_stride]; +#else +const uint32_t qf_stride = HSK / 32; +shared block_b_cache Qf[Br * qf_stride]; +#endif + +#ifndef MMQ const uint32_t D = HSK > HSV ? HSK : HSV; +#else +const uint32_t D = HSV; +#endif const uint32_t kvsh_stride = D / 4 + 1; shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1]; +#ifdef MMQ + +shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1]; +#endif + shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1]; +#ifdef MMQ +#include "flash_attn_mmq_funcs.glsl" +#endif + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -82,10 +108,39 @@ void main() { [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); uint32_t r = (idx + tid) / (HSK / 4); - if (r < Br && d < HSK / 4 && - i * Br + r < N) { + const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N; +#ifndef MMQ + if (is_in_bounds) { Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } +#else + const uint buf_ib = r * qf_stride + d / 8; + const uint buf_iqs = d % 8; + + FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f); + const FLOAT_TYPEV4 abs_vals = abs(vals); + + const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); + const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8); + const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0); + const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0); + vals = round(vals * qd_inv); + + Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals)); + +#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); + } +#else // Q4_0, Q4_1, Q5_0, Q5_1 + const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; + const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); + + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); + } +#endif +#endif } barrier(); @@ -195,6 +250,7 @@ void main() { if (SHMEM_STAGING != 0) { barrier(); +#ifndef MMQ [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); uint32_t c = (idx + tid) / (HSK / 4); @@ -214,9 +270,29 @@ void main() { kvsh[c * kvsh_stride + d] = K_Tf; } } +#else // MMQ + const uint ints_per_block = 8 / QUANT_R_MMQ; + const uint quant_iters = Bc * HSK / 32 * ints_per_block; + [[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) { + const uint32_t iqs = (idx + tid) % ints_per_block; + const uint32_t ib = (idx + tid) / ints_per_block; + const uint32_t c = ib / (HSK / 32); + const uint32_t block = ib % (HSK / 32); + if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) { + const uint buf_ib = c * qf_stride + block; + if (!KV_bounds_check || j * Bc + c < KV) { + const uint global_ib = (j * Bc + c) * k_stride + block; + k_block_to_shmem(buf_ib, global_ib, iqs, k_offset); + } else { + k_block_to_shmem_zero(buf_ib, iqs); + } + } + } +#endif // MMQ barrier(); } +#ifndef MMQ // More d iterations means Q register caching becomes relevant // Few iterations means the additional registers needed are worse than the speed-up from caching if (HSK_per_thread / 4 > 4) { @@ -275,6 +351,110 @@ void main() { } } } +#else // MMQ + const uint hsk4 = HSK_per_thread / 4; + const uint d_per_step = (hsk4 % 8 == 0) ? 8 : + (hsk4 % 4 == 0) ? 4 : + (hsk4 % 2 == 0) ? 2 : 1; + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + + [[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) { + int32_t k_quants[d_per_step]; + ACC_TYPEV2 k_dm; + + if (SHMEM_STAGING != 0) { + const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8; + const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx; +#if QUANT_AUXF == 1 + k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0); +#else + k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm); +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + if (d_per_step == 8) { + [[unroll]] for (uint32_t d = 0; d < 4; d++) { + uint vui = kblocksh[buf_ib].qs[d]; + k_quants[d ] = int32_t( vui & 0x0F0F0F0F); + k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF; + uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF; + k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); +#endif + } + } else +#endif + { + [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { + k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d); + } + } + } else { + const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block); + const uint ib = coord / BLOCK_SIZE; + const uint iqs = (coord % BLOCK_SIZE); + +#if QUANT_AUXF == 1 + k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0); +#else + k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset)); +#endif +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + if (d_per_step == 8) { +#if defined(DATA_A_Q5_0) + uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0], + k_packed.k_data_packed16[k_offset + ib].qh[1])); +#elif defined(DATA_A_Q5_1) + uint qh = k_packed.k_data_packed16[k_offset + ib].qh; +#endif + [[unroll]] for (uint32_t d = 0; d < 4; d++) { +#if defined(A_TYPE_PACKED32) + uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d]; +#else + uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0], + k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1])); +#endif + k_quants[d ] = int32_t( vui & 0x0F0F0F0F); + k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + uint qh_lo = (qh >> (d * 4)) & 0xF; + uint qh_hi = (qh >> (d * 4 + 16)) & 0xF; + k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); +#endif + } + } else +#endif + { + [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { + k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset); + } + } + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8; + const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8; + + int32_t acc = 0; + [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { + acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]); + } + + Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x; + if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) { + Sf[r][c] += k_dot_correction(qib, k_dm); + } + } + } + } +#endif // MMQ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { // Compute sum across the D_split diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index b30dee86871..6f349246915 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -89,6 +89,11 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif +#if defined(A_TYPE_PACKED32) +layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32; +layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32; +#endif + #ifndef BLOCK_SIZE #define BLOCK_SIZE 1 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl new file mode 100644 index 00000000000..e14e62d546a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl @@ -0,0 +1,149 @@ +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { +#ifdef DATA_A_Q4_0 + uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); +#else + uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; +#endif + + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + + return int32_t(vui & 0x0F0F0F0F); +} +#endif + +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { +#ifdef DATA_A_Q5_0 + uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0], + k_packed.k_data_packed16[a_offset + ib].qh[1])); +#else + uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint qh = k_packed.k_data_packed16[a_offset + ib].qh; +#endif + + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); +} +#endif + +#if defined(DATA_A_Q8_0) +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { + return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])); +} +#endif + +#if defined(DATA_A_IQ4_NL) +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { + uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + + u8vec4 idx = unpack8(vui & 0x0F0F0F0F); + return pack32(i8vec4(kvalues_iq4nl_const[idx.x], + kvalues_iq4nl_const[idx.y], + kvalues_iq4nl_const[idx.z], + kvalues_iq4nl_const[idx.w])); +} +#endif + +#if QUANT_AUXF == 1 +FLOAT_TYPE get_k_d(uint ib, uint a_offset) { + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d); +} +#else +FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) { + return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm); +} +#endif + +void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) { +#if defined(DATA_A_Q4_0) + kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], + k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); +#elif defined(DATA_A_Q4_1) + kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; +#elif defined(DATA_A_Q5_0) + kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], + k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); + if (iqs == 0) { + kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0], + k_packed.k_data_packed16[a_offset + global_ib].qh[1])); + } +#elif defined(DATA_A_Q5_1) + kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; + if (iqs == 0) { + kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh; + } +#elif defined(DATA_A_Q8_0) + kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], + k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); +#elif defined(DATA_A_IQ4_NL) + const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], + k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); + const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); + const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); + kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y], + kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w])); + kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y], + kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w])); +#endif + + if (iqs == 0) { +#if QUANT_AUXF == 1 + kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d); +#else + kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm); +#endif + } +} + +int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) { +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4 : 0; + return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); +#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4 : 0; + int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); + uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF; + return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u); +#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) + return kblocksh[buf_ib].qs[pos]; +#endif +} + +ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) { +#if defined(DATA_A_Q4_0) + return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; +#elif defined(DATA_A_Q5_0) + return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; +#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) + return ACC_TYPE(Qf[qib].ds.y) * k_dm.y; +#else + return ACC_TYPE(0.0); +#endif +} + +void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) { + kblocksh[buf_ib].qs[iqs] = 0; +#if defined(DATA_A_IQ4_NL) + kblocksh[buf_ib].qs[iqs + 4] = 0; +#endif + if (iqs == 0) { +#if QUANT_AUXF == 1 + kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f); +#else + kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f); +#endif + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl index c700f6e3f25..10552d013a2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -32,6 +32,12 @@ struct block_a_cache { int32_t qs[32/4]; FLOAT_TYPE dm; }; +#elif defined(DATA_A_IQ4_NL) +#define QUANT_R_MMQ 2 +struct block_a_cache { + int32_t qs[8]; + FLOAT_TYPE dm; +}; #elif defined(DATA_A_MXFP4) #define QUANT_R_MMQ 2 struct block_a_cache { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 4239070af5e..1fb592fb84b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -1692,6 +1692,7 @@ struct block_iq4_nl_packed16 #if defined(DATA_A_IQ4_NL) #define QUANT_K QUANT_K_IQ4_NL #define QUANT_R QUANT_R_IQ4_NL +#define QUANT_AUXF 1 #define A_TYPE block_iq4_nl #define A_TYPE_PACKED16 block_iq4_nl_packed16 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 77a55ea812b..607eef7d0d6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -406,8 +406,8 @@ std::map merge_maps(const std::map> compiles; -void string_to_spv(std::string name, const std::string& source, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); +void string_to_spv(std::string name, const std::string& source, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, const std::string& suffix = "") { + name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + suffix; std::string out_path = join_paths(output_dir, name + ".spv"); if (input_filepath == "") { @@ -625,15 +625,16 @@ void process_shaders() { for (const bool& fp16 : {false, true}) { std::map base_dict; if (fp16) { - base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}}; + base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}}; } else { - base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}}; + base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}}; } // flash attention for (const bool& f16acc : {false, true}) { std::map fa_base_dict = base_dict; fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV2"] = fp16 && f16acc ? "f16vec2" : "vec2"; fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4"; if (fp16 && f16acc) { fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; @@ -672,6 +673,12 @@ void process_shaders() { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (tname != "f32") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8"); + } +#endif } } } From cdeaa341742c4d558d7020079ef0e282803511a6 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 14 Apr 2026 11:34:23 +0200 Subject: [PATCH 126/249] vulkan: Support GGML_TYPE_NVFP4 (llama/21455) This adds nvfp4 support for get_rows, dequant, and mul_mat(_id). For mul_mat, it does not add support for the dp4/q8_1 path, it's all via fp16/fp32. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 28 +++++++++++ .../vulkan-shaders/copy_from_quant.comp | 2 +- .../vulkan-shaders/dequant_funcs.glsl | 25 ++++++++++ .../vulkan-shaders/dequant_funcs_cm2.glsl | 20 ++++++++ .../vulkan-shaders/dequant_nvfp4.comp | 32 +++++++++++++ .../vulkan-shaders/mul_mm_funcs.glsl | 17 +++++++ .../src/ggml-vulkan/vulkan-shaders/types.glsl | 47 ++++++++++++++++++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 3 +- 8 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1bee3e187cf..b353d041421 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3079,6 +3079,10 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec case GGML_TYPE_MXFP4: lut_size = 4*16; break; + case GGML_TYPE_NVFP4: + // Same kvalues budget as MXFP4 plus ue4m3_fp32_lut[128] (types.glsl, DATA_A_NVFP4). + lut_size = 4*16 + 128u * (uint32_t)sizeof(float); + break; default: break; } @@ -3558,6 +3562,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_NVFP4], matmul_nvfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) GGML_ASSERT(device->subgroup_ballot); @@ -3588,6 +3593,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) #undef CREATE_MM #undef CREATE_MM2 } else @@ -3651,6 +3657,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3674,6 +3681,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } GGML_ASSERT(device->subgroup_ballot); @@ -3708,6 +3716,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); #undef CREATE_MM2 #undef CREATE_MM } else @@ -3773,6 +3782,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3819,6 +3829,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3864,6 +3875,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3939,6 +3951,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3983,6 +3996,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_subgroup_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); } else { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -4010,6 +4024,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } } // reusing CREATE_MM from the fp32 path @@ -4108,6 +4123,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); @@ -4133,6 +4149,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -4184,6 +4201,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_NVFP4], "mul_mat_vec_id_nvfp4_f32", arr_dmmv_id_nvfp4_f32_f32_len[reduc16], arr_dmmv_id_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -4239,6 +4257,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -4265,6 +4284,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_NVFP4], "get_rows_nvfp4", get_rows_nvfp4_len, get_rows_nvfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -4291,6 +4311,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); @@ -6089,6 +6110,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6161,6 +6183,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6227,6 +6250,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6318,6 +6342,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6387,6 +6412,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -15373,6 +15399,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return false; @@ -15488,6 +15515,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_I32: return true; default: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index 06df5095258..6a692147478 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -4,7 +4,7 @@ #include "generic_unary_head.glsl" #include "dequant_funcs.glsl" -#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4) // 16 invocations needed for init_iq_shmem layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; #else diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index ede1275cfc2..88d07d2dfd5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -450,6 +450,25 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_NVFP4) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint sub = iqs >> 4; + const float d = ue4m3_to_fp32(data_a[a_offset + ib].d[sub]); + const uint j = iqs & 7; + const uint shift = (iqs & 8) >> 1; // 0 or 4 + const uint vui0 = uint(data_a[a_offset + ib].qs[sub * 8u + j]); + const uint vui1 = uint(data_a[a_offset + ib].qs[sub * 8u + j + 1]); + const uint qs0 = (vui0 >> shift) & 0xF; + const uint qs1 = (vui1 >> shift) & 0xF; + return vec2(float(kvalues_mxfp4[qs0]), float(kvalues_mxfp4[qs1])) * d * 0.5; +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const vec2 v1 = dequantize(ib, iqs + 2u, a_offset); + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) vec2 get_dm(uint ib, uint a_offset) { return vec2(0, 0); @@ -484,6 +503,12 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif +#if defined(DATA_A_NVFP4) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1.0, 0.0); +} +#endif + #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 03035f28120..c582aba87dc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -697,6 +697,24 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords } #endif +#if defined(DATA_A_NVFP4) +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4 { + block_nvfp4 block; +}; + +float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + const uint sub = (idx & 0x30) >> 4; + const uint iqs = ((idx & 0x30) >> 1) + (idx & 0x7); + const uint shift = (idx & 0x8) >> 1; + const float d = ue4m3_to_fp32(bl.block.d[sub]); + uint qs = uint(bl.block.qs[iqs]); + qs = (qs >> shift) & 0xF; + return float16_t(kvalues_mxfp4[qs] * d * 0.5); +} +#endif + #if defined(DATA_A_Q1_0) #define dequantFuncA dequantFuncQ1_0 #elif defined(DATA_A_Q4_0) @@ -743,6 +761,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords #define dequantFuncA dequantFuncIQ4_NL #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 +#elif defined(DATA_A_NVFP4) +#define dequantFuncA dequantFuncNVFP4 #elif defined(DATA_A_F32) #define dequantFuncA dequantFuncF32 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp new file mode 100644 index 00000000000..689089160b7 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_nvfp4 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint sub = tid / 16; + const uint ir = tid % 16; + const uint ib = 16 * i + ir; + if (ib >= p.nel / 64) { + return; + } + + const uint q_idx = 8 * sub; + const uint b_idx = 1024 * i + 64 * ir + 16 * sub; + + const float d = ue4m3_to_fp32(data_a[ib].d[sub]); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF])); + data_b[b_idx + l + 8] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 219bd608035..6e4a29d2fdd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -501,6 +501,23 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin kvalues_mxfp4[vui2 & 0xF] * d); buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, kvalues_mxfp4[vui2 >> 4] * d); +#elif defined(DATA_A_NVFP4) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + // lo and hi nibbles are 8 elements apart, which doesn't quite line up with + // how the thread mapping and buf_idx calculation works for other types. + const uint buf_idx = col * SHMEM_STRIDE + (row & 3) + (row & ~3) * 2; + + const uint ib = idx / 16u; + const uint sub = (idx & 0xC) >> 2; + const uint iqs = (idx & 0xF) * 2; + const float d = ue4m3_to_fp32(data_a[ib].d[sub]) * 0.5; + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + + buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 4] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); #endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 1fb592fb84b..4bcd97756fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -1713,6 +1713,22 @@ struct block_mxfp4 #define A_TYPE block_mxfp4 #endif +#define QUANT_K_NVFP4 64 +#define QUANT_R_NVFP4 1 + +struct block_nvfp4 +{ + uint8_t d[QUANT_K_NVFP4 / 16]; + uint8_t qs[QUANT_K_NVFP4 / 2]; +}; + +#if defined(DATA_A_NVFP4) +#define QUANT_K QUANT_K_NVFP4 +#define QUANT_R QUANT_R_NVFP4 +#define QUANT_AUXF 1 +#define A_TYPE block_nvfp4 +#endif + #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) const int8_t kvalues_iq4nl_const[16] = { int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), @@ -1732,7 +1748,7 @@ void init_iq_shmem(uvec3 wgsize) } #endif -#if defined(DATA_A_MXFP4) +#if defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4) const int8_t kvalues_mxfp4_const[16] = { int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12), int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12), @@ -1740,6 +1756,24 @@ const int8_t kvalues_mxfp4_const[16] = { shared int8_t kvalues_mxfp4[16]; +#if defined(DATA_A_NVFP4) +// UE4M3 scale in NVFP4 blocks use only 7 bits; sign (bit 7) is always zero. +shared float ue4m3_fp32_lut[128]; + +float ue4m3_to_fp32_build(uint u) { + if (u == 0u || u == 127u) { + return 0.0; + } + const uint exp = (u >> 3) & 15u; + const uint man = u & 7u; + if (exp == 0u) { + return float(man) * (1.0 / 512.0); + } + const uint bits = (exp + 120u) << 23 | (man << 20); + return uintBitsToFloat(bits); +} +#endif + #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { @@ -1747,6 +1781,11 @@ void init_iq_shmem(uvec3 wgsize) for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) { kvalues_mxfp4[i] = kvalues_mxfp4_const[i]; } +#if defined(DATA_A_NVFP4) + for (uint i = gl_LocalInvocationIndex.x; i < 128u; i += wgsize.x) { + ue4m3_fp32_lut[i] = ue4m3_to_fp32_build(i); + } +#endif barrier(); } #endif @@ -1783,6 +1822,12 @@ float e8m0_to_fp32(uint8_t x) { return uintBitsToFloat(bits); } +#if defined(DATA_A_NVFP4) +float ue4m3_to_fp32(uint8_t x) { + return ue4m3_fp32_lut[uint(x)]; +} +#endif + #if BDA #extension GL_EXT_buffer_reference : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 607eef7d0d6..b232927658b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -66,6 +66,7 @@ const std::vector type_names = { "iq4_xs", "iq4_nl", "mxfp4", + "nvfp4", "bf16", }; @@ -556,7 +557,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string load_vec_quant = "2"; if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4")) + else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4") || (tname == "nvfp4")) load_vec_quant = "4"; if (tname == "bf16") { From b732f4d9b5429c72f7e50ed7001588a4aa847380 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 14 Apr 2026 03:46:41 -0700 Subject: [PATCH 127/249] ggml-webgpu: Update register tiling matmul to use f32 accumulation (llama/21644) * Update register tiling matmul to use f32 accumulation * fix profiling code * Fix register tiling matmul for chrome, i'm blaming dawn * Update batch tuning value for iOS * compile fix * Fix use of new load function --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 51 +++++++------------ .../wgsl-shaders/mul_mat_decls.tmpl | 35 +++++-------- .../wgsl-shaders/mul_mat_reg_tile.wgsl | 12 ++--- .../wgsl-shaders/mul_mat_subgroup_matrix.wgsl | 3 ++ 4 files changed, 40 insertions(+), 61 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 634201bc64d..8d0e109365f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -79,7 +79,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim /* Constants */ -#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 32u +#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 64u #define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u #define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u #define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6) @@ -97,14 +97,6 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim /* End Constants */ -static inline wgpu::CallbackMode ggml_webgpu_callback_mode() { -#ifdef __EMSCRIPTEN__ - return wgpu::CallbackMode::AllowProcessEvents; -#else - return wgpu::CallbackMode::AllowSpontaneous; -#endif -} - // This is a "fake" base pointer, since WebGPU buffers do not have pointers to // their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT @@ -445,34 +437,25 @@ static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status, } #ifdef __EMSCRIPTEN__ -// iOS browsers seem to have very strict limits on the number of in-flight GPU commands, so we need to throttle to avoid failures. EM_JS(int, ggml_webgpu_is_ios_browser, (), { const ua = navigator.userAgent; return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0; }); #endif -static uint32_t ggml_backend_webgpu_get_max_inflight_batches(const wgpu::AdapterInfo & info) { +// TODO: these next two functions may want tuning across different platforms and workloads, +static uint32_t ggml_backend_webgpu_get_max_inflight_batches() { #ifdef __EMSCRIPTEN__ + // iOS has very strict limits on the number of in-flight GPU commands, + // so we need to throttle to avoid failures. if (ggml_webgpu_is_ios_browser()) { return 1; } -#else - GGML_UNUSED(info); #endif - return UINT32_MAX; } -static uint32_t ggml_backend_webgpu_get_command_submit_batch_size(const wgpu::AdapterInfo & info) { -#ifdef __EMSCRIPTEN__ - if (ggml_webgpu_is_ios_browser()) { - return 16; - } -#else - GGML_UNUSED(info); -#endif - +static uint32_t ggml_backend_webgpu_get_command_submit_batch_size() { return WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE; } @@ -482,7 +465,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) { const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( ctx->queue.OnSubmittedWorkDone( - ggml_webgpu_callback_mode(), + wgpu::CallbackMode::AllowSpontaneous, [&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { callback_status = status; callback_message = std::string(message); @@ -502,7 +485,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, std::string callback_message; const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( - buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(), + buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, [&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) { callback_status = status; callback_message = std::string(message); @@ -542,15 +525,15 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { #endif #ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx, - const std::vector & commands, - std::vector & futures) { +static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx, + const std::vector & commands, + std::vector & futures) { for (const auto & command : commands) { auto label = command.pipeline_name; auto ts_bufs = command.timestamp_query_bufs; wgpu::Future f = ts_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(), + wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { if (status != wgpu::MapAsyncStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); @@ -3428,7 +3411,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->instance.WaitAny( ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, ggml_webgpu_callback_mode(), + &options, wgpu::CallbackMode::AllowSpontaneous, [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { if (status != wgpu::RequestAdapterStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); @@ -3449,8 +3432,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { } #endif ctx->webgpu_global_ctx->adapter.GetInfo(&info); - ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(info); - ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(info); + ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(); + ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(); wgpu::SupportedFeatures features; ctx->webgpu_global_ctx->adapter.GetFeatures(&features); // we require f16 support @@ -3501,7 +3484,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { dev_desc.requiredFeatures = required_features.data(); dev_desc.requiredFeatureCount = required_features.size(); dev_desc.SetDeviceLostCallback( - ggml_webgpu_callback_mode(), + wgpu::CallbackMode::AllowSpontaneous, [ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { if (reason == wgpu::DeviceLostReason::Destroyed) { return; @@ -3535,7 +3518,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->instance.WaitAny( ctx->webgpu_global_ctx->adapter.RequestDevice( - &dev_desc, ggml_webgpu_callback_mode(), + &dev_desc, wgpu::CallbackMode::AllowSpontaneous, [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { if (status != wgpu::RequestDeviceStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str()); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 374137ff8e8..56a76a6e6c4 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -502,12 +502,6 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let d = load_f16_at(&src0, block_byte_base); let dmin = load_f16_at(&src0, block_byte_base + 2u); - // Load packed scales - var scale_vals: array; - for (var i: u32 = 0u; i < 3u; i++) { - scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i); - } - // Map k_in_block to loop structure: // Outer loop over 64-element groups (alternating q_b_idx) // Inner loop over 2 shifts per group @@ -523,15 +517,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var sc: u32; var mn: u32; + let scale_base = block_byte_base + 4u; + if (is < 4u) { - let sc_byte = get_byte(scale_vals[is / 4u], is % 4u); - let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); + let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u); + let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); sc = sc_byte & 63u; mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); - let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); - let min_hi = get_byte(scale_vals[is / 4u], is % 4u); + let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u); + let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u); + let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); @@ -578,11 +574,6 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let d = load_f16_at(&src0, block_byte_base); let dmin = load_f16_at(&src0, block_byte_base + 2u); - // Load packed scales - var scale_vals: array; - for (var i: u32 = 0u; i < 3u; i++) { - scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i); - } // The original loop processes elements in groups of 64 // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4] @@ -603,15 +594,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var sc: u32; var mn: u32; + let scale_base = block_byte_base + 4u; + if (is < 4u) { - let sc_byte = get_byte(scale_vals[is / 4u], is % 4u); - let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); + let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u); + let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); sc = sc_byte & 63u; mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); - let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); - let min_hi = get_byte(scale_vals[is / 4u], is % 4u); + let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u); + let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u); + let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index b1da421a691..ee37e6d249c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -4,14 +4,14 @@ enable f16; #include "mul_mat_decls.tmpl" #ifdef VEC -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { - return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { + return vec4(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]); } #endif #ifdef SCALAR -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { - return f32(acc[tm][tn]); +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { + return acc[tm][tn]; } #endif @@ -98,7 +98,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M; let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N; - var acc: array, TILE_M>; + var acc: array, TILE_M>; for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { @@ -122,7 +122,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let src1_idx = src1_n * TILE_K + k_inner; let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; for (var tm = 0u; tm < TILE_M; tm++) { - acc[tm][tn] += src0_tile[tm] * src1_val; + acc[tm][tn] += f32(src0_tile[tm]) * f32(src1_val); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 9f9ef279f29..4151ce430b0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -6,6 +6,9 @@ enable chromium_experimental_subgroup_matrix; #include "common_decls.tmpl" #include "mul_mat_decls.tmpl" +// TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs. +// See https://github.com/ggml-org/llama.cpp/issues/21602 + #ifdef VEC fn store_dst(shmem_idx: u32, dst_idx: u32) { dst[dst_idx] = vec4( From bfdcd4a92c0302905f8c6010642e0e87685d53b1 Mon Sep 17 00:00:00 2001 From: texasich <101962694+texasich@users.noreply.github.com> Date: Tue, 14 Apr 2026 05:47:56 -0500 Subject: [PATCH 128/249] cmake: fix CMP0194 warning on Windows with MSVC (llama/21630) * cmake: fix CMP0194 warning on Windows with MSVC Set CMP0194 policy to NEW before project() call in ggml/CMakeLists.txt to suppress the "MSVC is not an assembler for language ASM" warning introduced in CMake 4.1. The ggml project enables ASM globally for Metal (macOS) and KleidiAI (ARM) backends. On Windows/MSVC, no assembler sources are used, but CMake 4.1+ warns because cl.exe is not a valid ASM compiler. This follows the same pattern used in ggml-vulkan (CMP0114, CMP0147). Closes ggml-org/llama.cpp#20311 * cmake: apply cisc's formatting suggestion --------- Co-authored-by: texasich --- ggml/CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 6bf15723b3c..8454eecde6e 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -1,4 +1,11 @@ cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories. + +# ref: https://cmake.org/cmake/help/latest/policy/CMP0194.html +# MSVC is not a valid assembler for the ASM language. +# Set to NEW to avoid a warning on CMake 4.1+ with MSVC. +if (POLICY CMP0194) + cmake_policy(SET CMP0194 NEW) +endif() project("ggml" C CXX ASM) ### GGML Version From 80f7be74bb45e575f1cf2ab35e1ba8553358694a Mon Sep 17 00:00:00 2001 From: Richard Davison Date: Tue, 14 Apr 2026 13:23:45 +0200 Subject: [PATCH 129/249] ggml : fix ARM NEON nvfp4 dot product on non-dotprod targets (llama/21559) --- ggml/src/ggml-cpu/arch/arm/quants.c | 40 ++++++++++++++++++++++++----- ggml/src/ggml-cpu/ggml-cpu-impl.h | 10 ++++++++ 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index e09db59cf22..64d811fafe7 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -783,6 +783,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_1, m4b)); const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4)); +#if defined(__ARM_FEATURE_DOTPROD) const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs); const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16); const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b)); @@ -794,15 +795,40 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b)); const int32x4_t p0 = vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0), - ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0)); + vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0), + vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0)); const int32x4_t p1 = vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1), - ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1)); + vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1), + vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1)); - const int32x4_t sums = vpaddq_s32(p0, p1); + const int32x4_t sumi = vpaddq_s32(p0, p1); +#else + const int8x8_t q4_0_lo = vget_low_s8(q4_lo_0); + const int8x8_t q4_0_hi = vget_low_s8(q4_hi_0); + const int8x8_t q4_1_lo = vget_high_s8(q4_lo_0); + const int8x8_t q4_1_hi = vget_high_s8(q4_hi_0); + const int8x8_t q4_2_lo = vget_low_s8(q4_lo_1); + const int8x8_t q4_2_hi = vget_low_s8(q4_hi_1); + const int8x8_t q4_3_lo = vget_high_s8(q4_lo_1); + const int8x8_t q4_3_hi = vget_high_s8(q4_hi_1); + + const int8x8_t q8_0_lo = vld1_s8(y[2*ib].qs); + const int8x8_t q8_0_hi = vld1_s8(y[2*ib].qs + 8); + const int8x8_t q8_1_lo = vld1_s8(y[2*ib].qs + 16); + const int8x8_t q8_1_hi = vld1_s8(y[2*ib].qs + 24); + const int8x8_t q8_2_lo = vld1_s8(y[2*ib+1].qs); + const int8x8_t q8_2_hi = vld1_s8(y[2*ib+1].qs + 8); + const int8x8_t q8_3_lo = vld1_s8(y[2*ib+1].qs + 16); + const int8x8_t q8_3_hi = vld1_s8(y[2*ib+1].qs + 24); + + const int32x4_t sumi = (int32x4_t){ + vaddvq_s32(ggml_nvfp4_dot8(q4_0_lo, q8_0_lo, q4_0_hi, q8_0_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_1_lo, q8_1_lo, q4_1_hi, q8_1_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_2_lo, q8_2_lo, q4_2_hi, q8_2_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_3_lo, q8_3_lo, q4_3_hi, q8_3_hi)), + }; +#endif - // Decode 4 UE4M3 scales to f32 and multiply with q8 scales const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d); const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d); const float32x4_t nvsc = { @@ -813,7 +839,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo }; const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1}); - acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales); + acc = vfmaq_f32(acc, vcvtq_f32_s32(sumi), scales); } sumf = vaddvq_f32(acc); #else diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 88a9c9ec057..5d1ca5ffcc3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -306,6 +306,7 @@ inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { #if !defined(__ARM_FEATURE_DOTPROD) +// NOTE: this fallback produces the same total sum as native vdotq_s32 but with different per-lane grouping — do not use when individual lane values matter. inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); @@ -319,6 +320,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #endif // !defined(__ARM_FEATURE_DOTPROD) +static inline int32x4_t ggml_nvfp4_dot8(const int8x8_t q4_lo, const int8x8_t q8_lo, + const int8x8_t q4_hi, const int8x8_t q8_hi) { + const int16x8_t p_lo = vmull_s8(q4_lo, q8_lo); + const int16x8_t p_hi = vmull_s8(q4_hi, q8_hi); + const int32x4_t sum_lo = vpaddlq_s16(p_lo); + const int32x4_t sum_hi = vpaddlq_s16(p_hi); + return vaddq_s32(sum_lo, sum_hi); +} + #endif // defined(__ARM_NEON) #ifdef __wasm_simd128__ From 691b1d0826e9a1eceb955b527591aa23c287ebb0 Mon Sep 17 00:00:00 2001 From: Seyoung Jeong Date: Tue, 14 Apr 2026 21:43:59 +0900 Subject: [PATCH 130/249] metal : add XIELU unary op (llama/20802) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal-impl.h | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 7 +++++++ ggml/src/ggml-metal/ggml-metal.metal | 9 +++++++++ 5 files changed, 19 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index e8548b053e8..8e0836c0beb 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -250,6 +250,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break; case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break; case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break; + case GGML_UNARY_OP_XIELU: op_num = OP_UNARY_NUM_XIELU; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 40cacb46520..4c192da650f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1043,6 +1043,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_XIELU: return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 62b028f4a4a..e7433f2a658 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -127,6 +127,7 @@ #define OP_UNARY_NUM_CEIL 118 #define OP_UNARY_NUM_ROUND 119 #define OP_UNARY_NUM_TRUNC 120 +#define OP_UNARY_NUM_XIELU 121 #define OP_SUM_ROWS_NUM_SUM_ROWS 10 #define OP_SUM_ROWS_NUM_MEAN 11 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 846225d9077..5b426be103f 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -787,6 +787,13 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { args.max = ggml_get_op_params_f32(op, 1); } + if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) { + args.slope = ggml_get_op_params_f32(op, 1); // alpha_n + args.scale = ggml_get_op_params_f32(op, 2); // alpha_p + args.bias = ggml_get_op_params_f32(op, 3); // beta + args.val = ggml_get_op_params_f32(op, 4); // eps + } + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); if (pipeline.c4) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f67c5cd8a1d..445a4deca83 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1177,6 +1177,15 @@ kernel void kernel_unary_impl( if (FC_OP == OP_UNARY_NUM_TRUNC) { dst_ptr[i0] = (T) trunc(x); } + + if (FC_OP == OP_UNARY_NUM_XIELU) { + const TC xi = x; + const TC gate = TC(xi > TC(0.0f)); + const TC clamped = fmin(xi, TC(args.val)); + const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi; + const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi; + dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg); + } } #undef FC_OP From 7024f7e5c12e7b0c42f5edddf69ed3210caf497a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 Apr 2026 15:58:09 +0300 Subject: [PATCH 131/249] ci : re-enable mac workflows (llama/21894) * ci : re-enable mac workflows * vulkan : fix compile warning --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 8d0e109365f..aa3fe06d5a9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3485,7 +3485,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { dev_desc.requiredFeatureCount = required_features.size(); dev_desc.SetDeviceLostCallback( wgpu::CallbackMode::AllowSpontaneous, - [ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { if (reason == wgpu::DeviceLostReason::Destroyed) { return; } From 45365fa1116f13586a89b9b6ed67e956e5f7399b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 14 Apr 2026 15:17:45 +0200 Subject: [PATCH 132/249] vulkan: Programmatically add RoundingModeRTE to all shaders when the device supports it (llama/21572) * vulkan: Programmatically add RoundingModeRTE to all shaders when the device supports it * use FetchContent to get SPIRV-Headers * Fetch spirv-headers unconditionally * remove fetchcontent, rely on installed headers * fix ubuntu job * Update docs/build.md --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 201 ++++++++++-------- .../vulkan-shaders/copy_to_quant.comp | 1 - ggml/src/ggml-vulkan/vulkan-shaders/diag.comp | 1 - ggml/src/ggml-vulkan/vulkan-shaders/exp.comp | 1 - .../vulkan-shaders/generic_binary_head.glsl | 1 - .../ggml-vulkan/vulkan-shaders/glu_head.glsl | 1 - .../ggml-vulkan/vulkan-shaders/im2col.comp | 1 - .../ggml-vulkan/vulkan-shaders/im2col_3d.comp | 1 - ggml/src/ggml-vulkan/vulkan-shaders/log.comp | 1 - .../ggml-vulkan/vulkan-shaders/multi_add.comp | 1 - .../ggml-vulkan/vulkan-shaders/rope_head.glsl | 1 - .../vulkan-shaders/rope_params.glsl | 2 - ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl | 5 - ggml/src/ggml-vulkan/vulkan-shaders/tri.comp | 1 - .../vulkan-shaders/vulkan-shaders-gen.cpp | 91 +++----- 15 files changed, 138 insertions(+), 172 deletions(-) delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b353d041421..b2a54bd85d0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -20,6 +20,13 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher() #include +// SPIRV-Headers: LunarG Windows SDK uses Include/spirv-headers/spirv.hpp (not spirv/unified1/). MinGW/MSYS2 and +// Linux packages use Khronos layout spirv/unified1/spirv.hpp. See docs/build.md#vulkan. +#if defined(_WIN32) && !defined(__MINGW32__) +#include +#else +#include +#endif #include #include @@ -2131,6 +2138,66 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); + + // Patch SPIR-V to enable RTE rounding for FP16, avoiding the need for + // separate shader variants compiled with -DRTE16. + std::vector spv; + if (device->float_controls_rte_fp16) { + const uint32_t* spv_words = reinterpret_cast(spv_data); + size_t word_count = spv_size / sizeof(uint32_t); + spv.assign(spv_words, spv_words + word_count); + + // Find insertion points respecting SPIR-V layout order: + // Header(5) -> OpCapability -> OpExtension -> ... -> OpEntryPoint -> OpExecutionMode -> ... + size_t pos = 5; // skip header + size_t cap_insert_pos = pos; + size_t ext_insert_pos = pos; + size_t exec_insert_pos = pos; + uint32_t entry_point_id = 0; + + while (pos < spv.size()) { + uint32_t opcode = spv[pos] & spv::OpCodeMask; + uint32_t len = spv[pos] >> spv::WordCountShift; + if (len == 0) break; + + if (opcode == spv::OpCapability) { + cap_insert_pos = pos + len; + ext_insert_pos = pos + len; + } else if (opcode == spv::OpExtension) { + ext_insert_pos = pos + len; + } else if (opcode == spv::OpEntryPoint) { + entry_point_id = spv[pos + 2]; + exec_insert_pos = pos + len; + } else if (opcode == spv::OpExecutionMode || opcode == spv::OpExecutionModeId) { + exec_insert_pos = pos + len; + } else if (entry_point_id != 0) { + break; + } + + pos += len; + } + + // Insert from latest position first so earlier indices stay valid. + + // OpExecutionMode %entrypoint RoundingModeRTE 16 + uint32_t exec_mode[] = { (4u << spv::WordCountShift) | spv::OpExecutionMode, entry_point_id, spv::ExecutionModeRoundingModeRTE, 16 }; + spv.insert(spv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode)); + + // OpExtension "SPV_KHR_float_controls" + const char ext_str[] = "SPV_KHR_float_controls"; + size_t ext_str_words = CEIL_DIV(sizeof(ext_str), sizeof(uint32_t)); + std::vector extension(1 + ext_str_words, 0); + extension[0] = (uint32_t)((1 + ext_str_words) << spv::WordCountShift) | spv::OpExtension; + memcpy(&extension[1], ext_str, sizeof(ext_str)); + spv.insert(spv.begin() + ext_insert_pos, extension.begin(), extension.end()); + + // OpCapability RoundingModeRTE + uint32_t capability[] = { (2u << spv::WordCountShift) | spv::OpCapability, spv::CapabilityRoundingModeRTE }; + spv.insert(spv.begin() + cap_insert_pos, std::begin(capability), std::end(capability)); + + shader_module_create_info = vk::ShaderModuleCreateInfo({}, spv.size() * sizeof(uint32_t), spv.data()); + } + pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); vk::PushConstantRange pcr( @@ -4344,10 +4411,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); - if (device->float_controls_rte_fp16 && - sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) { + if (sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) { ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_len, rms_norm_mul_rope_f32_f16_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -4372,43 +4438,28 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_rte_len, cpy_f32_q1_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - } else { - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - } - -#define SET_ROWS(itype, rte) \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## rte ## _len, set_rows_q1_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - - if (device->float_controls_rte_fp16) { - SET_ROWS(_i32, _rte) - SET_ROWS(_i64, _rte) - } else { - SET_ROWS(_i32, ) - SET_ROWS(_i64, ) - } + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + +#define SET_ROWS(itype) \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## _len, set_rows_f32 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## _len, set_rows_f16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## _len, set_rows_bf16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## _len, set_rows_q1_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## _len, set_rows_q4_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## _len, set_rows_q4_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## _len, set_rows_q5_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + + SET_ROWS(_i32) + SET_ROWS(_i64) #undef SET_ROWS @@ -4428,11 +4479,10 @@ static void ggml_vk_load_shaders(vk_device& device) { return s; }; - bool rte = device->float_controls_rte_fp16; #define CREATE_BINARY(name, namemod, spec, bindings) \ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ - #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); CREATE_BINARY(add, , {0}, 4) @@ -4475,13 +4525,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - } else { - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - } + ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4522,19 +4567,9 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(floor) CREATE_UNARY(trunc) CREATE_UNARY(sgn) + CREATE_UNARY(exp) #undef CREATE_UNARY -#define CREATE_UNARY_RTE(name) \ - if (device->float_controls_rte_fp16) { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - } else { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - } - CREATE_UNARY_RTE(exp) -#undef CREATE_UNARY_RTE - ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -4544,13 +4579,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); #define CREATE_GLU(name) \ - if (device->float_controls_rte_fp16) { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - } else { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - } + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); CREATE_GLU(geglu) CREATE_GLU(reglu) @@ -4583,25 +4613,14 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - } else { - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - } + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2); @@ -4663,13 +4682,8 @@ static void ggml_vk_load_shaders(vk_device& device) { #define IM2COL(bda) \ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ - if (device->float_controls_rte_fp16) { \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ - } else { \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ - } + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); if (device->shader_int64 && device->buffer_device_address) { IM2COL(_bda) } else { @@ -14343,8 +14357,7 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co } // conditions for pipeline creation - if (!(ctx->device->float_controls_rte_fp16 && - sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) { + if (sizeof(vk_op_rms_norm_mul_rope_push_constants) > ctx->device->properties.limits.maxPushConstantsSize) { return false; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index 4ffa45485c9..710c15296da 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #if defined(SET_ROWS) && QUANT_K == 1 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp index cd3f42f4911..79761324f55 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp index b69d4ddb096..c7cf5ec68f7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "generic_head.glsl" #include "types.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index ba7909c4d38..dc657f3c708 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -1,7 +1,6 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.glsl" #include "utils.glsl" #if RMS_NORM_ROPE_FUSION #include "rope_params.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl index 95298922d83..d8fdd8f7b5e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -1,6 +1,5 @@ #extension GL_EXT_shader_16bit_storage : require -#include "rte.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index db14f5a3cf3..674f91e5ed2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -3,7 +3,6 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.glsl" #include "types.glsl" layout (push_constant) uniform parameter diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp index 4bf8b4ca046..93f61fd8543 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -4,7 +4,6 @@ #extension GL_EXT_control_flow_attributes : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "rte.glsl" #include "types.glsl" layout (push_constant) uniform parameter diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp index ff2812d3d75..3cda6a63c45 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp index 10cf5202a4a..26d194e9e8d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -8,7 +8,6 @@ #extension GL_KHR_shader_subgroup_basic : enable #endif -#include "rte.glsl" #include "types.glsl" #include "utils.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl index d9b4d4c03f3..51a127bcd87 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -2,7 +2,6 @@ #extension GL_EXT_shader_16bit_storage : require -#include "rte.glsl" #include "rope_params.glsl" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index ec6ceaca9bd..2e2a7e14c66 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -1,8 +1,6 @@ #if !defined(GGML_ROPE_PARAMS) #define GGML_ROPE_PARAMS -#include "rte.glsl" - struct rope_params { uint rope_mode; uint nrows; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl deleted file mode 100644 index ad51c1e80b8..00000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +++ /dev/null @@ -1,5 +0,0 @@ - -#if RTE16 -#extension GL_EXT_spirv_intrinsics : enable -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif // RTE16 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp index e18d0ffa307..f9b78f96072 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index b232927658b..54b9b327333 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -745,7 +745,7 @@ void process_shaders() { string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}})); - string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}})); + string_to_spv("rms_norm_mul_rope_f32_f16", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -769,15 +769,12 @@ void process_shaders() { for (std::string t : {"q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { - string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); - string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } auto get_type_str = [](bool f16) { @@ -794,12 +791,10 @@ void process_shaders() { for (auto src0_f16 : {false, true}) { for (auto src1_f16 : {false, true}) { for (auto dst_f16 : {false, true}) { - for (auto rte : {false, true}) { auto source = op == "add_rms" ? std::string("add") : op; - auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); auto add_rms = op == "add_rms" ? "1" : "0"; - string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}}); - } + string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , add_rms}}); } } } @@ -847,14 +842,11 @@ void process_shaders() { string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - for (auto rte : {false, true}) { - std::string suffix = rte ? "_rte" : ""; - string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}}); + string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - } + string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); @@ -908,21 +900,18 @@ void process_shaders() { string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - for (auto rte : {false, true}) { - std::string suffix = rte ? "_rte" : ""; - string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - } + string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("swiglu_oai_f16", "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("swiglu_oai_f32", "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -942,25 +931,18 @@ void process_shaders() { string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); @@ -983,7 +965,6 @@ void process_shaders() { std::string bda_def = bda ? "1" : "0"; string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}})); string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}})); - string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}})); } } @@ -1036,8 +1017,8 @@ void process_shaders() { string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); - string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); + string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "0"}}); + string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "1"}}); string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}}); string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); @@ -1090,8 +1071,8 @@ void write_output_files() { std::string suffixes[2] = {"_f32", "_f16"}; for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) { - hdr << "extern const void * " << op << "_data[2][2][2][2];\n"; - hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n"; + hdr << "extern const void * " << op << "_data[2][2][2];\n"; + hdr << "extern const uint64_t " << op << "_len[2][2][2];\n"; std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp"; if (basename(input_filepath) != op_file) { @@ -1099,8 +1080,8 @@ void write_output_files() { } std::stringstream data = make_generic_stringstream(); std::stringstream len = make_generic_stringstream(); - data << "const void * " << op << "_data[2][2][2][2] = "; - len << "const uint64_t " << op << "_len[2][2][2][2] = "; + data << "const void * " << op << "_data[2][2][2] = "; + len << "const uint64_t " << op << "_len[2][2][2] = "; for (uint32_t t0 = 0; t0 < 2; ++t0) { if (t0 == 0) { data << "{"; @@ -1116,20 +1097,10 @@ void write_output_files() { data << "{"; len << "{"; } - for (uint32_t rte = 0; rte < 2; ++rte) { - if (rte == 0) { - data << "{"; - len << "{"; - } - data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); - len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); - data << "_data,"; - len << "_len,"; - if (rte == 1) { - data << "}, "; - len << "}, "; - } - } + data << op << suffixes[t0] << suffixes[t1] << suffixes[t2]; + len << op << suffixes[t0] << suffixes[t1] << suffixes[t2]; + data << "_data,"; + len << "_len,"; if (t2 == 1) { data << "}, "; len << "}, "; From 08e412c862ca2274aedb16f314376a81cd32b9a6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 Apr 2026 17:32:29 +0300 Subject: [PATCH 133/249] metal : fix FA support logic (llama/21898) --- ggml/src/ggml-metal/ggml-metal-device.m | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 4c192da650f..effe666a691 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1160,6 +1160,23 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te if (op->src[1]->type != op->src[2]->type) { return false; } + switch (op->src[1]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + break; + case GGML_TYPE_BF16: + if (!has_bfloat) { + return false; + } + break; + default: + return false; + } return has_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: From 44d86c4921c1d8ba48e946c1885983311e6055b1 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Tue, 14 Apr 2026 16:32:58 +0200 Subject: [PATCH 134/249] ggml : remove ggml-ext.h (llama/21869) * ggml: correct placement of ggml-ext.h * ggml : remove ggml-ext.h --------- Co-authored-by: Georgi Gerganov --- ggml/include/ggml-backend.h | 47 ++++++++++++++++++++++++++++++++++ ggml/src/ggml-alloc.c | 1 + ggml/src/ggml-backend-meta.cpp | 3 --- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 3c06aeaffb1..4a8f6d4287d 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -348,6 +348,53 @@ extern "C" { // Set a callback to be called for each resulting node during graph compute GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data); + // + // Meta backend + // + +#define GGML_BACKEND_META_MAX_DEVICES 16 + + enum ggml_backend_meta_split_axis { + // tensor split by tensor dimensions: + GGML_BACKEND_SPLIT_AXIS_0 = 0, + GGML_BACKEND_SPLIT_AXIS_1 = 1, + GGML_BACKEND_SPLIT_AXIS_2 = 2, + GGML_BACKEND_SPLIT_AXIS_3 = 3, + + GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends + GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum + + // for internal bookkeeping only: + GGML_BACKEND_SPLIT_AXIS_NONE = 98, + GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99, + }; + GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis); + + struct ggml_backend_meta_split_state { + enum ggml_backend_meta_split_axis axis; + + // for tensors with axis >= 0 && axis < GGML_MAX_DIMS: + // - each device has a slice of the tensor along the split axis + // - most tensors have n_segments == 1 and a contiguous slice of the tensor data + // - some tensors have an inhomogenenous data layout along the split axis, + // those tensors are divided into segments which are each individually split across devices + // - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis, + // the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1], + // - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments + // that each need to be split individually across devices so that each device gets a slice of Q, K, and V + int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES]; + uint32_t n_segments; + }; + + // function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible: + typedef struct ggml_backend_meta_split_state(*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata); + + // create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this: + // TODO: this looks a bit strange - a backend API creates a device. I think we should try + // express this as a backend registry functionality instead + GGML_API ggml_backend_dev_t ggml_backend_meta_device( + ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud); + // // Utils // diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index e9b70398ffc..a4b01ccf8a1 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -2,6 +2,7 @@ #include "ggml-backend-impl.h" #include "ggml.h" #include "ggml-impl.h" + #include #include #include diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index a2ab8872c4a..0a8eea4e945 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -5,9 +5,6 @@ #include "ggml-alloc.h" #include "ggml-cpp.h" -// TODO: tmp -#include "ggml-ext.h" - #include #include #include From 24cc89e477bea0336e911316d036c24dee5258a8 Mon Sep 17 00:00:00 2001 From: Yiwei Shao <44545837+njsyw1997@users.noreply.github.com> Date: Tue, 14 Apr 2026 14:09:03 -0700 Subject: [PATCH 135/249] hexagon: optimization for HMX mat_mul (llama/21554) * hexagon: add async HMX worker Introduce hmx-worker (dedicated thread for HMX compute) to overlap HMX matmul with HVX dequant/DMA stages in the pipeline path, replacing the previous synchronous HMX calls that blocked the main thread. * hexagon: cost-based VTCM chunk search for out-stationary matmul * hexagon: fix futex race in hmx_worker_drain Store the boolean to local variable avoid atomic load twice * hex-mm: hmx optimize scatter/transpose and use HMX intrinsics * hex-vmem: drop vmem limit a touch under 3GB on v73 * hexagon: add fwd declaration of htp_context * hex-hmx: replace hmx-worker with hmx-queue that mimics dma-queue interface Simplifies the overall implemantion, reduces thread wakeup roundtrips. * hex-mm: add debug log to hmx work func called from hmx-queue * Update hmx-queue.h Co-authored-by: Max Krasnyansky --------- Co-authored-by: Kim-Chyan Gan Co-authored-by: Max Krasnyansky Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/hex-utils.h | 15 +- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 388 +++++++++++++-------- ggml/src/ggml-hexagon/htp/hmx-queue.c | 158 +++++++++ ggml/src/ggml-hexagon/htp/hmx-queue.h | 134 +++++++ ggml/src/ggml-hexagon/htp/hmx-utils.h | 56 --- ggml/src/ggml-hexagon/htp/htp-ctx.h | 7 + ggml/src/ggml-hexagon/htp/htp-ops.h | 5 + ggml/src/ggml-hexagon/htp/hvx-base.h | 5 + ggml/src/ggml-hexagon/htp/main.c | 17 +- 10 files changed, 589 insertions(+), 197 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hmx-queue.c create mode 100644 ggml/src/ggml-hexagon/htp/hmx-queue.h diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 2b60f427ada..9ca759459d4 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -47,6 +47,7 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE + hmx-queue.c hmx-matmul-ops.c ) diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index fe0b661e309..f6713c5cf8f 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -31,6 +31,14 @@ static inline uint64_t hex_get_pktcnt() { return pktcnt; } +static inline uint32_t hex_ceil_pow2(uint32_t x) { + if (x <= 1) { return 1; } + int p = 2; + x--; + while (x >>= 1) { p <<= 1; } + return p; +} + static inline size_t hmx_ceil_div(size_t num, size_t den) { return (num + den - 1) / den; } @@ -73,8 +81,7 @@ static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, #define HEX_L2_LINE_SIZE 64 #define HEX_L2_FLUSH_SIZE (128 * 1024) -static inline void hex_l2flush(void * addr, size_t size) -{ +static inline void hex_l2flush(void * addr, size_t size) { if (size > HEX_L2_FLUSH_SIZE) { qurt_mem_cache_clean((qurt_addr_t) 0, 0, QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE); } else { @@ -89,4 +96,8 @@ static inline void hex_l2flush(void * addr, size_t size) } } +static inline void hex_pause() { + asm volatile(" pause(#255)\n"); +} + #endif /* HEX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index ec191c14981..485ec3f1aa9 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -16,14 +16,16 @@ #include "ggml-common.h" #include "hex-dma.h" +#include "worker-pool.h" + #include "hvx-utils.h" #include "hvx-dump.h" -#include "worker-pool.h" #include "htp-ctx.h" #include "htp-ops.h" -#include "hmx-utils.h" #include "hmx-ops.h" +#include "hmx-utils.h" +#include "hmx-queue.h" #include "hmx-profile.h" static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { @@ -47,7 +49,8 @@ static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = { 0*128, 1*128, 2*128, 3*128, 4*128, 5*128, 6*128, 7*128, 8*128, 9*128, 10*128, 11*128, 12*128, 13*128, 14*128, 15*128, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 16*128, 17*128, 18*128, 19*128, 20*128, 21*128, 22*128, 23*128, + 24*128, 25*128, 26*128, 27*128, 28*128, 29*128, 30*128, 31*128 }; // Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes @@ -109,36 +112,45 @@ static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) { return false; } -// Search for optimal (mc, nc) chunk sizes that maximize mc * nc within VTCM budget. +// Search for optimal (mc, nc) chunk sizes within VTCM budget. +// +// VTCM model: nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead // -// Cost model: total = nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead -// per_n_cost: bytes per nc column (weight + scratch buffers) -// per_m_cost: bytes per mc row (activation) -// per_mn_cost: bytes per mc*nc element (output) -// overhead: fixed bytes (scales 256B, eye_tile 2048B, etc.) +// Minimize ceil(m/mc) * m_block_cost + ceil(n/nc) * n_block_cost. +// All matmul paths repeat weight processing per M-block and activation loading +// per N-block, so discrete block counts drive total overhead. +// Tie-break: when cost is equal, prefer larger mc * nc. +// +// Caller-provided coefficients: +// m_block_cost: penalty per extra M-block (weight redundancy, scales with n). +// n_block_cost: penalty per extra N-block (activation redundancy, scales with m). // // Algorithm: nc sweeps from n_max down by 32, analytically solving for mc_max. // Returns 0 on success, -1 if VTCM is insufficient. -static int hmx_compute_chunks( - size_t vtcm_total, size_t overhead, - size_t per_n_cost, size_t per_m_cost, size_t per_mn_cost, - int m, int n, - size_t *m_chunk_out, size_t *n_chunk_out, - size_t *total_out) -{ +static int hmx_compute_chunks(size_t vtcm_total, + size_t overhead, + size_t per_n_cost, + size_t per_m_cost, + size_t per_mn_cost, + int m, + int n, + size_t m_block_cost, + size_t n_block_cost, + size_t * m_chunk_out, + size_t * n_chunk_out, + size_t * total_out) { if (m <= 0 || n <= 0) return -1; if (vtcm_total <= overhead) return -1; if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1; const size_t usable = vtcm_total - overhead; - size_t best_mn = 0, best_m = 0, best_n = 0; + + size_t best_cost = SIZE_MAX; + size_t best_mn = 0; + size_t best_m = 0, best_n = 0; const size_t n_max = hex_align_down((size_t)n, HMX_FP16_TILE_N_COLS); for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) { - // Early exit: if nc * m_max cannot beat best, smaller nc won't either - if (nc * hex_align_down((size_t)m, HMX_FP16_TILE_N_ROWS) <= best_mn) - break; - size_t n_fixed = 0, ncmn = 0, mc_denom = 0; if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue; if (n_fixed >= usable) goto next_nc; @@ -152,10 +164,19 @@ static int hmx_compute_chunks( mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS); mc = hex_smin(mc, (size_t)m); - if (mc > 0 && mc * nc > best_mn) { - best_mn = mc * nc; - best_m = mc; - best_n = nc; + if (mc == 0) { + goto next_nc; + } + + size_t mblocks = ((size_t) m + mc - 1) / mc; + size_t nblocks = ((size_t) n + nc - 1) / nc; + size_t cost = mblocks * m_block_cost + nblocks * n_block_cost; + size_t mn = mc * nc; + if (cost < best_cost || (cost == best_cost && mn > best_mn)) { + best_cost = cost; + best_mn = mn; + best_m = mc; + best_n = nc; } } @@ -233,7 +254,7 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx( const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_scales = hvx_vec_splat_f16(*scale); // q4x4x2 stores two int4 values per byte. Keep only the selected nibble. - HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); // Shuffle before LUT v_quants = Q6_Vb_vshuff_Vb(v_quants); @@ -257,7 +278,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( // Load all 128 packed bytes (4 contiguous 32-byte groups) HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); // Shuffle before LUT @@ -277,10 +298,8 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter - out[0] = v_lo; // group0 already in [0:63] - out[1] = Q6_V_vror_VR(v_lo, 64); // group1 rotated to [0:63] - out[2] = v_hi; // group2 already in [0:63] - out[3] = Q6_V_vror_VR(v_hi, 64); // group3 rotated to [0:63] + out[0] = v_lo; // group0 already in [0:63] + out[1] = v_hi; // group2 already in [0:63] } // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. @@ -384,8 +403,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( size_t row_stride, int weight_type, int start_tile, int end_tile) { - const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; - const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2); + const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS; + const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); + const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block; const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : (weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) : @@ -398,47 +418,46 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes) - for (int t = start_tile; t < end_tile; ) { - int ct = t / n_k_tiles; // column tile index - int kt = t % n_k_tiles; // K tile index + unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index + unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index + for (unsigned t = start_tile; t < end_tile; ) { + if (kt >= n_k_tiles) { kt = 0; ct++; } - // --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row --- - if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) && - ((t + 3) / n_k_tiles == ct)) { - int blk_idx = (kt * 32) / QK_Q4_0x4x2; - int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 - bool upper = (sub_blk_base >= 4); - int packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes - int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE - + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales + // --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row --- + if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; + unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 + bool upper = (sub_blk_base >= 4); + unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes + unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales __fp16 *tile_bases[4]; - for (int g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } + for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - const uint8_t *r0 = vtcm_src + row0 * row_stride; - const uint8_t *r1 = vtcm_src + row1 * row_stride; - HVX_Vector v0[4], v1[4]; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); - if (row1 < n_cols) { - dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt, v1); - } else { - v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero(); - } + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; + unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); } + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + HVX_Vector v0[2]; + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); } + + + r0 = vtcm_src + row_offset; row_offset += row_stride; + dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } - - t += 4; + t += 4; kt += 4; continue; } @@ -495,20 +514,19 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( // --- Single-tile fallback --- __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS; - if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) { - int blk_idx = (kt * 32) / QK_Q4_0x4x2; - int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; - bool upper = (sub_blk >= 4); - int byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; - int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + if (is_q4) { + unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; + unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; + bool upper = (sub_blk >= 4); + unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; + unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); HVX_Vector v_off = v_scat_base; // reset to column 0 - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t *r0 = vtcm_src + row0 * row_stride; - const uint8_t *r1 = vtcm_src + row1 * row_stride; + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; + unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx( r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); @@ -585,7 +603,7 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( } (void) *(volatile HVX_Vector *)(tile_base); } - ++t; + ++t; ++kt; } // Drain HVX scatter write buffer: a vmem load on the same HW thread retires @@ -653,9 +671,13 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( // --- End x4x2 dequantizers --- // requires external HMX lock -static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const __fp16 *weight, const __fp16 *scales, +static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, int n_row_tiles, int n_col_tiles, int n_dot_tiles) { - hmx_set_output_scales(scales); + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)scales); for (int r = 0; r < n_row_tiles; ++r) { for (int c = 0; c < n_col_tiles; ++c) { @@ -665,16 +687,55 @@ static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; for (int k = 0; k < n_dot_tiles; ++k) { - int offset = k * HMX_FP16_TILE_N_ELMS; - hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset); + Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; } __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; - hmx_consume_accumulator_fp16(out_tile); + Q6_mxmem_AR_after_hf(out_tile, 0); } } } +// --- Async HMX matmul job (for pipeline overlap) --- + +typedef struct { + __fp16 * output; + const __fp16 * activation; + const __fp16 * weight; + const __fp16 * scales; + uint32_t n_row_tiles; + uint32_t n_col_tiles; + uint32_t n_dot_tiles; +} hmx_matmul_job_t; + +static void hmx_matmul_worker_fn(void * data) { + hmx_matmul_job_t * job = (hmx_matmul_job_t *) data; + FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); + core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); +} + +static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, + __fp16 * output, + const __fp16 * activation, + const __fp16 * weight, + const __fp16 * scales, + int n_row_tiles, + int n_col_tiles, + int n_dot_tiles) { + job->output = output; + job->activation = activation; + job->weight = weight; + job->scales = scales; + job->n_row_tiles = n_row_tiles; + job->n_col_tiles = n_col_tiles; + job->n_dot_tiles = n_dot_tiles; +} + +// --- End async HMX matmul job --- + static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; @@ -832,12 +893,13 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, - /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, - /*per_mn=*/sizeof(__fp16), - params->m, params->n, - &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + /*per_n=*/3 * vec_dot_size, + /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, + /*per_mn=*/sizeof(__fp16), params->m, params->n, + /*m_block_cost=*/(size_t) params->n, + /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); } @@ -1006,13 +1068,15 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. if (hmx_compute_chunks(vtcm_budget, - /*overhead=*/ 256, - /*per_n=*/ 3 * vec_dot_size, // W + S0 + S1 - /*per_m=*/ vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch - /*per_mn=*/ sizeof(__fp16), // O - m, n, - &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, // W + S0 + S1 + /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch + /*per_mn=*/sizeof(__fp16), // O + m, n, + /*m_block_cost=*/(size_t) n, + /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); return -1; } @@ -1157,6 +1221,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m, int k, int n, int w_type); +#define FALLBACK_TO_STANDARD 1 + int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, int weight_type) { @@ -1169,9 +1235,12 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // for large m, k (e.g. prefill FFN Down), use out-stationary version if (m >= 128 && k > n && n > 1024) { - FARF(MEDIUM, "hmx_matmul_qk: OUT-STATIONARY path m=%d k=%d n=%d type=%d (K_BLOCK=512, %d K-iters with fp16 intermediate)", - m, k, n, weight_type, (k + 511) / 512); - return mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); + int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); + if (rc != FALLBACK_TO_STANDARD) { + return rc; // 0 success, -1 error + } + FARF(MEDIUM, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); + // fall through to standard path } size_t row_stride = get_x4x2_row_stride(weight_type, k); @@ -1197,9 +1266,10 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds } size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, - per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost, - m, n, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + // Quantized weight: dequant ~1.5x more expensive per element than activation load. + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost, m, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d pipe=%d budget=%zu)", __func__, m, k, n, use_pipeline, vtcm_budget); return -1; @@ -1256,9 +1326,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - if (!use_pipeline) { + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { // transfer activation matrix chunk into VTCM size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); @@ -1318,20 +1387,22 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds TIMER_STOP(output_store); } } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); } else { // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) - // stage B and D (dequantize and store) are expected to be on the critical path + // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). // A --> B: vtcm_qweight, 1 buffer // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers // C --> D: vtcm_output0/vtcm_output1, 2 buffers - // - // LD ||A3| | B3 || - // MM || C2 || - // ST || D1 | || + // Async timeline (C overlaps B+D): + // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] + // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); @@ -1352,31 +1423,34 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); } - // prologue: B0, A1, C0, B1 + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) { - // B0 + // B0: wait for DMA, dequant weight chunk 0 dma_queue_pop(ctx->dma[0]); dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); - // A1 + // A1: issue DMA for weight chunk 1 const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); if (1 < n_chunk_cnt) { const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); } - // C0 - core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - // B1 + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) if (1 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); } } - // main loop + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) for (int i = 0; i < n_chunk_cnt; ++i) { const size_t nc = i * n_chunk_n_cols; const size_t nc_p1 = nc + 1 * n_chunk_n_cols; @@ -1386,36 +1460,41 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - // issue A_{i+2} + // issue A_{i+2}: DMA push (non-blocking) if (i + 2 < n_chunk_cnt) { const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); } - // wait for HMX (C_{i}) -- C_{i} is done - - // result of B_{i+1} (input of C_{i+1}) should be ready now + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); - // issue C_{i+1} + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's + // counterpart — and (i+1)%2 was last used by C_{i-1} which completed + // before C_i was submitted. if (i + 1 < n_chunk_cnt) { - core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[(i + 1) % 2], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); } - // compute D_{i} + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) float *output_chunk = dst + (mr * n + nc); transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); - // wait for DMA (A_{i+2}), compute B_{i+2} + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) if (i + 2 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); } } } - } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + hmx_queue_suspend(ctx->hmx_queue); + } TIMER_STOP(total); @@ -1434,10 +1513,13 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds } // C += AB -void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp16 *col_scales, const __fp16 *eye_tile, +void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); - hmx_set_output_scales(col_scales); + Q6_bias_mxmem2_A((void *)col_scales); for (int i = 0; i < n_row_tiles; ++i) { for (int j = 0; j < n_col_tiles; ++j) { @@ -1448,15 +1530,17 @@ void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp __fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS; if (!zero_init) { - hmx_load_tile_pair_fp16(accum_tile, eye_tile); + Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); } for (int k = 0; k < n_dot_tiles; ++k) { - int offset = k * HMX_FP16_TILE_N_ELMS; - hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset); + Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; } - - hmx_consume_accumulator_fp16(accum_tile); + Q6_mxmem_AR_after_hf(accum_tile, 0); } } } @@ -1540,12 +1624,41 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict const size_t vtcm_budget = ctx->vtcm_size; - const size_t M_BLOCK_SIZE = 512; - const size_t N_BLOCK_SIZE = 512; - const size_t K_BLOCK_SIZE = 512; + const size_t K_BLOCK_SIZE = 1024; - // Compute precise buffer sizes + // Fallback: if k doesn't need K-blocking, out-stationary has no advantage + const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; + if (k_iters_check <= 1) { + FARF(MEDIUM, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); + return FALLBACK_TO_STANDARD; + } + + // Dynamic M,N search via hmx_compute_chunks const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); + const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) + + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) + const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) + + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) + const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) + // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; + // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin + const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; + const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment + + size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; + // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. + // From profiling: wt_dequant per element ≈ 1.5× activation load per element. + // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). + // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). + const size_t m_block_cost = (size_t) n * 3; + const size_t n_block_cost = (size_t) m * 2; + if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE, + &N_BLOCK_SIZE, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } + + // Compute precise buffer sizes from searched M,N and fixed K const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); @@ -1554,7 +1667,8 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; if (total_vtcm > vtcm_budget) { - FARF(HIGH, "%s: VTCM too small: need %zu have %zu (m=%d k=%d n=%d)", __func__, total_vtcm, vtcm_budget, m, k, n); + FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, + vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); return -1; } @@ -1568,8 +1682,8 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu", __func__, m, k, n, weight_type, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", __func__, m, k, n, weight_type, + M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); // initialize eye tile (32x32 identity matrix) { diff --git a/ggml/src/ggml-hexagon/htp/hmx-queue.c b/ggml/src/ggml-hexagon/htp/hmx-queue.c new file mode 100644 index 00000000000..5b1d83a0cbf --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-queue.c @@ -0,0 +1,158 @@ +#pragma clang diagnostic ignored "-Wunused-function" + +#include +#include +#include + +#include +#include + +#include + +#include "hmx-queue.h" + +#define QURT_LOWEST_PRIO (254) + +static inline void hmx_lock(struct hmx_queue *q) +{ + if (!q->hmx_locked) { + HAP_compute_res_hmx_lock(q->hap_rctx); + q->hmx_locked = true; + } +} + +static inline void hmx_unlock(struct hmx_queue *q) +{ + if (q->hmx_locked) { + HAP_compute_res_hmx_unlock(q->hap_rctx); + q->hmx_locked = false; + } +} + +static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) { + unsigned int ir = atomic_load(&q->idx_read); + + while (ir != atomic_load(&q->idx_write)) { + struct hmx_queue_desc *d = &q->desc[ir]; + if (!d->done) { + FARF(HIGH, "hmx-queue-process: ir %u func %p data %p", ir, d->func, d->data); + + enum hmx_queue_signal sig = (enum hmx_queue_signal) (unsigned int) d->func; + switch (sig) { + case HMX_QUEUE_NOOP: /* noop */; break; + case HMX_QUEUE_KILL: *killed = true; break; + case HMX_QUEUE_SUSPEND: hmx_unlock(q); break; + default: + hmx_lock(q); + d->func(d->data); + break; + } + + atomic_fetch_add(&d->done, 1); + } + + ir = (ir + 1) & q->idx_mask; + atomic_store(&q->idx_read, ir); + } +} + +static void hmx_queue_thread(void * arg) { + struct hmx_queue * q = (struct hmx_queue *) arg; + + FARF(HIGH, "hmx-queue-thread: started"); + + bool killed = false; + + unsigned int poll_cnt = HMX_QUEUE_POLL_COUNT; + unsigned int prev_seqn = 0; + while (!killed) { + unsigned int seqn = atomic_load(&q->seqn); + if (seqn == prev_seqn) { + if (--poll_cnt) { hex_pause(); continue; } + FARF(HIGH, "hmx-queue-thread: sleeping"); + qurt_futex_wait(&q->seqn, prev_seqn); + continue; + } + prev_seqn = seqn; + poll_cnt = HMX_QUEUE_POLL_COUNT; + + FARF(HIGH, "hmx-queue-thread: new work"); + + hmx_queue_process(q, &killed); + } + + FARF(HIGH, "hmx-queue-thread: stopped"); +} + +struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx) { + capacity = hex_ceil_pow2(capacity); + + struct hmx_queue * q = (struct hmx_queue *) memalign(32, sizeof(struct hmx_queue)); + if (q == NULL) { + FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__); + return NULL; + } + memset(q, 0, sizeof(struct hmx_queue)); + q->capacity = capacity; + q->idx_mask = capacity - 1; + q->hap_rctx = hap_rctx; + + q->desc = (struct hmx_queue_desc *) memalign(64, capacity * sizeof(struct hmx_queue_desc)); + if (!q->desc) { + FARF(ERROR, "hmx-queue: failed to allocate HMX queue descriptors\n"); + return NULL; + } + memset(q->desc, 0, capacity * sizeof(struct hmx_queue_desc)); + + const size_t stack_size = HMX_QUEUE_THREAD_STACK_SIZE; + q->stack = (unsigned char *) memalign(64, stack_size); + if (!q->stack) { + FARF(ERROR, "hmx-queue: thread stack allocation failed (%zu bytes)", stack_size); + return NULL; + } + memset(q->stack, 0, stack_size); + + // Match caller thread priority (same pattern as worker-pool.c). + int prio = qurt_thread_get_priority(qurt_thread_get_id()); + if (prio < 1) { + prio = 1; + } + if (prio > QURT_LOWEST_PRIO) { + prio = QURT_LOWEST_PRIO; + } + + qurt_thread_attr_t attr; + qurt_thread_attr_init(&attr); + qurt_thread_attr_set_stack_addr(&attr, q->stack); + qurt_thread_attr_set_stack_size(&attr, stack_size); + qurt_thread_attr_set_priority(&attr, prio); + qurt_thread_attr_set_name(&attr, "hmx-queue"); + + int err = qurt_thread_create(&q->thread, &attr, hmx_queue_thread, q); + if (err) { + FARF(ERROR, "hmx-worker: thread create failed (%d)", err); + return NULL; + } + + FARF(HIGH, "hmx-queue: capacity %u\n", capacity); + + return q; +} + +void hmx_queue_delete(struct hmx_queue * q) { + if (!q) { + return; + } + + // Tell the worker to exit. + hmx_queue_flush(q); + hmx_queue_signal(q, HMX_QUEUE_KILL); + hmx_queue_flush(q); + + int status; + qurt_thread_join(q->thread, &status); + + free(q->desc); + free(q->stack); + free(q); +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-queue.h b/ggml/src/ggml-hexagon/htp/hmx-queue.h new file mode 100644 index 00000000000..0d48c280f52 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-queue.h @@ -0,0 +1,134 @@ +#ifndef HMX_QUEUE_H +#define HMX_QUEUE_H + +#include +#include +#include + +#include +#include +#include +#include + +#include "hex-utils.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define HMX_QUEUE_THREAD_STACK_SIZE (16 * 1024) +#define HMX_QUEUE_POLL_COUNT 2000 + +typedef void (*hmx_queue_func)(void *); + +// Dummy funcs used as signals +enum hmx_queue_signal { + HMX_QUEUE_NOOP = 0, // aka NULL + HMX_QUEUE_SUSPEND, + HMX_QUEUE_KILL +}; + +struct hmx_queue_desc { + hmx_queue_func func; + void * data; + atomic_uint done; +}; + +struct hmx_queue { + struct hmx_queue_desc * desc; + atomic_uint idx_write; // updated by producer (push) + atomic_uint idx_read; // updated by consumer (process) + unsigned int idx_pop; // updated by producer (pop) + uint32_t idx_mask; + uint32_t capacity; + + atomic_uint seqn; // incremented for all pushes, used with futex + qurt_thread_t thread; + void * stack; + uint32_t hap_rctx; + bool hmx_locked; +}; + +struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx); +void hmx_queue_delete(struct hmx_queue * q); + +static inline struct hmx_queue_desc hmx_queue_make_desc(hmx_queue_func func, void * data) { + struct hmx_queue_desc d = { func, data }; + return d; +} + +static inline bool hmx_queue_push(struct hmx_queue * q, struct hmx_queue_desc d) { + unsigned int ir = atomic_load(&q->idx_read); + unsigned int iw = q->idx_write; + + if (((iw + 1) & q->idx_mask) == ir) { + FARF(HIGH, "hmx-queue-push: queue is full\n"); + return false; + } + + atomic_store(&d.done, 0); + + FARF(HIGH, "hmx-queue-push: iw %u func %p data %p\n", iw, d.func, d.data); + + q->desc[iw] = d; + atomic_store(&q->idx_write, (iw + 1) & q->idx_mask); + // wake up our thread + atomic_fetch_add(&q->seqn, 1); + qurt_futex_wake(&q->seqn, 1); + + return true; +} + +static inline bool hmx_queue_signal(struct hmx_queue *q, enum hmx_queue_signal sig) { + return hmx_queue_push(q, hmx_queue_make_desc((hmx_queue_func) sig, NULL)); +} + +static inline bool hmx_queue_empty(struct hmx_queue * q) { + return q->idx_pop == q->idx_write; +} + +static inline uint32_t hmx_queue_depth(struct hmx_queue * q) { + return (q->idx_read - q->idx_read) & q->idx_mask; +} + +static inline uint32_t hmx_queue_capacity(struct hmx_queue * q) { + return q->capacity; +} + +static inline struct hmx_queue_desc hmx_queue_pop(struct hmx_queue * q) { + unsigned int ip = q->idx_pop; + unsigned int iw = q->idx_write; + + struct hmx_queue_desc rd = { NULL, NULL }; + if (ip == iw) { + return rd; + } + + // Wait for desc to complete + struct hmx_queue_desc * d = &q->desc[ip]; + while (!atomic_load(&d->done)) { + FARF(HIGH, "hmx-queue-pop: waiting for HMX queue : %u\n", ip); + hex_pause(); + } + + rd = *d; + q->idx_pop = (ip + 1) & q->idx_mask; + + FARF(HIGH, "hmx-queue-pop: ip %u func %p data %p\n", ip, rd.func, rd.data); + return rd; +} + +static inline void hmx_queue_flush(struct hmx_queue * q) { + while (hmx_queue_pop(q).func != NULL) ; +} + +static inline void hmx_queue_suspend(struct hmx_queue *q) { + hmx_queue_signal(q, HMX_QUEUE_SUSPEND); + hmx_queue_flush(q); +} + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* HMX_QUEUE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hmx-utils.h b/ggml/src/ggml-hexagon/htp/hmx-utils.h index aacfbcda287..af04619cebb 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hmx-utils.h @@ -14,10 +14,6 @@ #define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline)) -static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) { - asm volatile("bias = mxmem2(%0)" :: "r"(scales)); -} - // Initialise aligned 256-byte area with scale vector + zero padding. static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) { HVX_Vector *pv = (HVX_Vector *)out_scales; @@ -25,58 +21,6 @@ static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vecto *pv = Q6_V_vzero(); } -// Load multiple contiguous tiles with :deep streaming. -// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt]. -// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank -// boundary, otherwise the mxmem instruction will raise a precise bus error. -// Callers must ensure their VTCM layout satisfies this constraint. -static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles, - const __fp16 *col_tiles, - size_t n_tiles) { - size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1; - asm volatile( - "{ activation.hf = mxmem(%0, %1):deep\n" - "weight.hf = mxmem(%2, %3) }\n" - :: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit) - : "memory"); -} - -// Load a single activation+weight tile pair (no :deep streaming). -// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula -// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047. -// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation -// places a tile near a 4 MB bank boundary, the oversized region crosses it and -// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly -// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047). -static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile, - const __fp16 *wt_tile) { - asm volatile( - "{ activation.hf = mxmem(%0, %1)\n" - "weight.hf = mxmem(%2, %3) }\n" - :: "r"(act_tile), "r"(2047), - "r"(wt_tile), "r"(2047) - : "memory"); -} - -static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) { - // Use the combined convert-and-store instruction (matches the reference - // Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence - // "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter. - asm volatile( - "mxmem(%0, %1):after.hf = acc\n" - :: "r"(out), "r"(0) - : "memory"); -} - -// Compute inner product of two vectors of tiles and store result. -static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out, - const __fp16 *row_tiles, - const __fp16 *col_tiles, - size_t n_tiles) { - hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles); - hmx_consume_accumulator_fp16(out); -} - // --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) --- static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 4c36a6ea0c2..8b5e47adef8 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -2,6 +2,7 @@ #define HTP_CTX_H #include "hex-dma.h" +#include "hmx-queue.h" #include "htp-ops.h" #include "worker-pool.h" @@ -30,6 +31,8 @@ struct htp_spad { uint32_t size_per_thread; // size per thread }; +struct htp_context; + // Context while processing an Op // TODO: fold this into the main context struct htp_ops_context { @@ -72,6 +75,10 @@ struct htp_context { atomic_bool vtcm_needs_release; struct htp_ops_context octx; + +#ifdef HTP_HAS_HMX + struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap +#endif }; int op_matmul(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 44a6ab4f737..fa84b674cd2 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -91,7 +91,12 @@ enum htp_op_code { #define HTP_OP_MAX_BUFS 8 #define HTP_OP_MAX_REQS 256 #define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS) + +#if __HVX_ARCH__ < 75 +#define HTP_OP_MAX_VMEM (3167538380u) +#else #define HTP_OP_MAX_VMEM (3221225472u) +#endif enum htp_tensor_flags { HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights) diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index db05ab40d28..ed6026e762a 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -116,9 +116,14 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) { } static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) { +#if __HVX_ARCH__ >= 81 + HVX_Vector q0 = Q6_Vqf32_equals_Vsf(v0); + HVX_Vector q1 = Q6_Vqf32_equals_Vsf(v1); +#else const HVX_Vector zero = Q6_V_vzero(); HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero); HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero); +#endif return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)); } diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 8b347039428..d71c97ed292 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -18,8 +18,9 @@ #include #include -#include "hex-dma.h" #include "hex-utils.h" +#include "hex-dma.h" +#include "hmx-queue.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -324,6 +325,14 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que #ifdef HTP_HAS_HMX ctx->hmx_enabled = use_hmx; + ctx->hmx_queue = NULL; + if (use_hmx) { + ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx); + if (!ctx->hmx_queue) { + FARF(ERROR, "hmx-queue-create failed"); + ctx->hmx_enabled = false; + } + } FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx); #endif @@ -389,7 +398,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) { } #ifdef HTP_HAS_HMX - ctx->hmx_enabled = 0; + if (ctx->hmx_queue) { + hmx_queue_delete(ctx->hmx_queue); + ctx->hmx_queue = NULL; + } + ctx->hmx_enabled = false; #endif vtcm_free(ctx); From 86d94cd95bb043772f6153d0add5bf6a204e066d Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Wed, 15 Apr 2026 14:45:16 +0200 Subject: [PATCH 136/249] docs: more extensive RoPE documentation [no ci] (llama/21953) * more extensive ggml_rope documentation * add more docs * nits --- ggml/include/ggml.h | 56 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 11d3e8a8167..703e3783136 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1773,8 +1773,32 @@ extern "C" { int n_dims, int mode); - // custom RoPE + // RoPE operations with extended options + // a is the input tensor to apply RoPE to, shape [n_embd, n_head, n_token] + // b is an int32 vector with size n_token // c is freq factors (e.g. phi3-128k), (optional) + // mode can be GGML_ROPE_TYPE_NORMAL or NEOX; for MROPE and VISION mode, use ggml_rope_multi + // + // pseudo-code for computing theta: + // for i in [0, n_dims/2): + // theta[i] = b[i] * powf(freq_base, -2.0 * i / n_dims); + // theta[i] = theta[i] / c[i]; # if c is provided, divide theta by c + // theta[i] = rope_yarn(theta[i], ...); # note: theta = theta * freq_scale is applied here + // + // other params are used by YaRN RoPE scaling, these default values will disable YaRN: + // freq_scale = 1.0f + // ext_factor = 0.0f + // attn_factor = 1.0f + // beta_fast = 0.0f + // beta_slow = 0.0f + // + // example: + // (marking: c = cos, s = sin, 0 = unrotated) + // given a single head with size = 8 --> [00000000] + // GGML_ROPE_TYPE_NORMAL n_dims = 4 --> [cscs0000] + // GGML_ROPE_TYPE_NORMAL n_dims = 8 --> [cscscscs] + // GGML_ROPE_TYPE_NEOX n_dims = 4 --> [ccss0000] + // GGML_ROPE_TYPE_NEOX n_dims = 8 --> [ccccssss] GGML_API struct ggml_tensor * ggml_rope_ext( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1790,6 +1814,36 @@ extern "C" { float beta_fast, float beta_slow); + // multi-dimensional RoPE, for Qwen-VL and similar vision models + // mode can be either VISION, MROPE, IMROPE, cannot be combined with NORMAL or NEOX + // sections specify how many dimensions to rotate in each section: + // section length is equivalent to number of cos/sin pairs, NOT the number of dims + // (i.e. sum of 4 sections are expected to be n_dims/2) + // last sections can be 0, means ignored + // all other options are identical to ggml_rope_ext + // + // important note: + // - NEOX ordering is automatically applied and cannot be disabled for MROPE and VISION + // if you need normal ordering, there are 2 methods: + // (1) split the tensor manually using ggml_view + // (2) permute the weight upon conversion + // - for VISION, n_dims must be head_size/2 + // + // example M-RoPE: + // given sections = [t=4, y=2, x=2, 0] + // given a single head with size = 18 --> [000000000000000000] + // GGML_ROPE_TYPE_MROPE n_dims = 16 --> [ttttyyxxttttyyxx00] (cos/sin are applied in NEOX ordering) + // GGML_ROPE_TYPE_IMROPE n_dims = 16 --> [ttyxttyxttyxttyx00] (interleaved M-RoPE, still NEOX ordering) + // note: the theta for each dim is computed the same way as ggml_rope_ext, no matter the section + // in other words, idx used for theta: [0123456789... until n_dims/2], not reset for each section + // + // example vision RoPE: + // given sections = [y=4, x=4, 0, 0] (last 2 sections are ignored) + // given a single head with size = 8 --> [00000000] + // GGML_ROPE_TYPE_VISION n_dims = 4 --> [yyyyxxxx] + // other values of n_dims are untested and is undefined behavior + // note: unlike MROPE, the theta for each dim is computed differently for each section + // in other words, idx used for theta: [0123] for y section, then [0123] for x section GGML_API struct ggml_tensor * ggml_rope_multi( struct ggml_context * ctx, struct ggml_tensor * a, From 182db04cb2e6ce68b5bfa17571222b179f3840ae Mon Sep 17 00:00:00 2001 From: Valeriy Dubov Date: Wed, 15 Apr 2026 16:44:02 +0300 Subject: [PATCH 137/249] rpc : add native RDMA transport for RPC backend (RoCEv2) (llama/20590) --- ggml/include/ggml-rpc.h | 6 +- ggml/src/ggml-rpc/CMakeLists.txt | 23 ++ ggml/src/ggml-rpc/ggml-rpc.cpp | 610 +++++++++++++++++++++++++++++-- 3 files changed, 601 insertions(+), 38 deletions(-) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 1c11495b66e..6fcf5a43393 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -6,9 +6,9 @@ extern "C" { #endif -#define RPC_PROTO_MAJOR_VERSION 3 -#define RPC_PROTO_MINOR_VERSION 6 -#define RPC_PROTO_PATCH_VERSION 1 +#define RPC_PROTO_MAJOR_VERSION 4 +#define RPC_PROTO_MINOR_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 0 #ifdef __cplusplus static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); diff --git a/ggml/src/ggml-rpc/CMakeLists.txt b/ggml/src/ggml-rpc/CMakeLists.txt index f5acb8ec2cb..8671ce5ceaf 100644 --- a/ggml/src/ggml-rpc/CMakeLists.txt +++ b/ggml/src/ggml-rpc/CMakeLists.txt @@ -7,3 +7,26 @@ ggml_add_backend_library(ggml-rpc if (WIN32) target_link_libraries(ggml-rpc PRIVATE ws2_32) endif() + +# RDMA auto-detection (Linux only, requires libibverbs) +if (NOT WIN32 AND NOT APPLE) + find_library(IBVERBS_LIB ibverbs) + if (IBVERBS_LIB) + option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" ON) + else() + option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" OFF) + endif() +else() + set(GGML_RPC_RDMA OFF CACHE BOOL "RDMA not available on this platform" FORCE) +endif() + +if (GGML_RPC_RDMA) + if (NOT IBVERBS_LIB) + find_library(IBVERBS_LIB ibverbs REQUIRED) + endif() + target_compile_definitions(ggml-rpc PRIVATE GGML_RPC_RDMA) + target_link_libraries(ggml-rpc PRIVATE ${IBVERBS_LIB}) + message(STATUS " RDMA transport enabled (auto-detected)") +else() + message(STATUS " RDMA transport disabled") +endif() diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 61bfcc5a675..017ef0af360 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -3,7 +3,9 @@ #include "ggml-backend-impl.h" #include "ggml-cpp.h" +#include #include +#include #include #include #include @@ -31,6 +33,14 @@ #include #include +#ifdef GGML_RPC_RDMA +# include +# include +# ifndef _WIN32 +# include +# endif +#endif // GGML_RPC_RDMA + static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); #define LOG_DBG(...) \ @@ -49,17 +59,116 @@ typedef int sockfd_t; #endif // cross-platform socket + +#ifdef GGML_RPC_RDMA +static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) +static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB +static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes +using rdma_gid_t = std::array; + +struct rdma_conn { + struct ibv_context * ctx = nullptr; + struct ibv_pd * pd = nullptr; + struct ibv_cq * scq = nullptr; // send completions + struct ibv_cq * rcq = nullptr; // recv completions + struct ibv_qp * qp = nullptr; + + void * tx_buf = nullptr; + struct ibv_mr * tx_mr = nullptr; + + void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous + struct ibv_mr * rx_mr = nullptr; + int rx_head = 0; + + uint32_t max_inline = 0; + + uint8_t * rx_slot(int i) const { + return static_cast(rx_buf) + static_cast(i) * RDMA_CHUNK; + } + + bool post_rx(int i) { + struct ibv_sge sge = {}; + sge.addr = (uintptr_t)rx_slot(i); + sge.length = RDMA_CHUNK; + sge.lkey = rx_mr->lkey; + struct ibv_recv_wr wr = {}, * bad = nullptr; + wr.wr_id = (uint64_t)i; + wr.sg_list = &sge; + wr.num_sge = 1; + return ibv_post_recv(qp, &wr, &bad) == 0; + } + + ~rdma_conn() { + if (tx_mr) ibv_dereg_mr(tx_mr); + if (rx_mr) ibv_dereg_mr(rx_mr); + free(tx_buf); + free(rx_buf); + if (qp) ibv_destroy_qp(qp); + if (scq) ibv_destroy_cq(scq); + if (rcq) ibv_destroy_cq(rcq); + if (pd) ibv_dealloc_pd(pd); + if (ctx) ibv_close_device(ctx); + } +}; + +// Local RDMA parameters captured during the probe phase and later consumed +// by rdma_activate() after the remote side's caps arrive via HELLO. +struct rdma_local_info { + uint32_t qpn = 0; + uint32_t psn = 0; + uint8_t gid[RDMA_GID_SIZE] = {}; + uint8_t ib_port = 0; + int gid_idx = 0; + enum ibv_mtu path_mtu = IBV_MTU_1024; +}; +#endif // GGML_RPC_RDMA + +// conn_caps size for transport-agnostic capability exchange +static constexpr size_t RPC_CONN_CAPS_SIZE = 24; + +// conn_caps RDMA layout helper +#ifdef GGML_RPC_RDMA +struct rdma_caps { + uint32_t qpn; + uint32_t psn; + uint8_t gid[RDMA_GID_SIZE]; +}; +static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); +#endif // GGML_RPC_RDMA + +// Forward declarations for transport function pointers +struct socket_t; +static bool tcp_send_impl(socket_t * sock, const void * data, size_t size); +static bool tcp_recv_impl(socket_t * sock, void * data, size_t size); + struct socket_t { sockfd_t fd; + bool (*fn_send)(socket_t *, const void *, size_t) = tcp_send_impl; + bool (*fn_recv)(socket_t *, void *, size_t) = tcp_recv_impl; +#ifdef GGML_RPC_RDMA + std::unique_ptr rdma; + rdma_local_info rdma_local = {}; +#endif // GGML_RPC_RDMA socket_t(sockfd_t fd) : fd(fd) {} ~socket_t() { +#ifdef GGML_RPC_RDMA + rdma.reset(); +#endif // GGML_RPC_RDMA LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); #ifdef _WIN32 - closesocket(this->fd); + if (fd != INVALID_SOCKET) closesocket(this->fd); #else - close(this->fd); + if (fd >= 0) close(this->fd); #endif } + + // Advertise local transport capabilities into conn_caps. + // May probe RDMA and store the probe on this socket for update_caps. + void get_caps(uint8_t * caps); + + // Activate transport upgrade based on remote conn_caps using the probe + // previously stored by get_caps. + void update_caps(const uint8_t * remote_caps); }; // macro for nicer error messages on server crash @@ -115,10 +224,16 @@ static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14"); // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold const size_t HASH_THRESHOLD = 10 * 1024 * 1024; +struct rpc_msg_hello_req { + uint8_t conn_caps[RPC_CONN_CAPS_SIZE]; +}; + struct rpc_msg_hello_rsp { uint8_t major; uint8_t minor; uint8_t patch; + uint8_t padding; + uint8_t conn_caps[RPC_CONN_CAPS_SIZE]; }; struct rpc_msg_device_count_rsp { @@ -414,27 +529,414 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) { return true; } -static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) { - if (!send_data(sockfd, &msg_size, sizeof(msg_size))) { +// TCP transport implementations (for function-pointer dispatch) + +static bool tcp_send_impl(socket_t * sock, const void * data, size_t size) { + return send_data(sock->fd, data, size); +} + +static bool tcp_recv_impl(socket_t * sock, void * data, size_t size) { + return recv_data(sock->fd, data, size); +} + +// RDMA transport (performance-optimized, auto-negotiated) + +#ifdef GGML_RPC_RDMA + +static bool rdma_send_impl(socket_t * sock, const void * data, size_t size); +static bool rdma_recv_impl(socket_t * sock, void * data, size_t size); + +static inline bool tcp_peer_closed(int fd) { + if (fd < 0) return false; +#ifndef _WIN32 + struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; + int r = poll(&pfd, 1, 0); + return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP)); +#else + return false; +#endif +} + +static inline bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc, int tcp_fd) { + for (uint64_t s = 0; ; s++) { + int n = ibv_poll_cq(cq, 1, wc); + if (n > 0) { + if (wc->status != IBV_WC_SUCCESS) { + GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", + wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); + } + return wc->status == IBV_WC_SUCCESS; + } + if (n < 0) return false; + if ((s & 0xFFFFF) == 0 && s > 0) { + if (tcp_peer_closed(tcp_fd)) { + return false; + } + } + } +} + +static bool rdma_send(rdma_conn * c, const void * data, size_t size, int tcp_fd) { + const uint8_t * src = (const uint8_t *)data; + size_t rem = size; + while (rem > 0) { + size_t chunk = std::min(rem, RDMA_CHUNK); + + struct ibv_sge sge = {}; + struct ibv_send_wr wr = {}, * bad = nullptr; + wr.opcode = IBV_WR_SEND; + wr.sg_list = &sge; + wr.num_sge = 1; + + if (chunk <= c->max_inline) { + sge.addr = (uintptr_t)src; + sge.length = chunk; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; + } else { + memcpy(c->tx_buf, src, chunk); + sge.addr = (uintptr_t)c->tx_buf; + sge.length = chunk; + sge.lkey = c->tx_mr->lkey; + wr.send_flags = IBV_SEND_SIGNALED; + } + + if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; + struct ibv_wc wc; + if (!rdma_poll(c->scq, &wc, tcp_fd)) return false; + + src += chunk; + rem -= chunk; + } + return true; +} + + +static bool rdma_recv(rdma_conn * c, void * data, size_t size, int tcp_fd) { + uint8_t * dst = (uint8_t *)data; + size_t rem = size; + while (rem > 0) { + struct ibv_wc wc; + if (!rdma_poll(c->rcq, &wc, tcp_fd)) return false; + + int slot = (int)wc.wr_id; + size_t got = wc.byte_len; + memcpy(dst, c->rx_slot(slot), got); + + if (!c->post_rx(slot)) return false; + + dst += got; + rem -= got; + } + return true; +} + +static bool rdma_send_impl(socket_t * sock, const void * data, size_t size) { + return rdma_send(sock->rdma.get(), data, size, sock->fd); +} + +static bool rdma_recv_impl(socket_t * sock, void * data, size_t size) { + return rdma_recv(sock->rdma.get(), data, size, sock->fd); +} + +// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. +// Used to match the socket's local IP against the kernel's GID table so that +// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: +// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) +// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) +// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is +// Returns std::nullopt on unsupported family or getsockname failure. +static std::optional rdma_build_target_gid(sockfd_t tcp_fd) { + sockaddr_storage addr = {}; + socklen_t addr_len = sizeof(addr); + if (getsockname(tcp_fd, reinterpret_cast(&addr), &addr_len) != 0) { + return std::nullopt; + } + rdma_gid_t target = {}; + if (addr.ss_family == AF_INET) { + const auto * a = reinterpret_cast(&addr); + target[10] = 0xff; + target[11] = 0xff; + memcpy(&target[12], &a->sin_addr, 4); + return target; + } + if (addr.ss_family == AF_INET6) { + const auto * a = reinterpret_cast(&addr); + memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); + return target; + } + return std::nullopt; +} + +static rdma_conn * rdma_probe(sockfd_t tcp_fd, rdma_local_info * out) { + const char * dev_env = std::getenv("GGML_RDMA_DEV"); + const char * gid_env = std::getenv("GGML_RDMA_GID"); + + auto target_gid = rdma_build_target_gid(tcp_fd); + if (!target_gid) { + return nullptr; + } + + const uint8_t ib_port = 1; + int num_devs = 0; + ibv_device ** devs = ibv_get_device_list(&num_devs); + if (!devs || num_devs == 0) return nullptr; + + ibv_context * ibctx = nullptr; + const char * matched_dev = nullptr; + int gid_idx = gid_env ? atoi(gid_env) : -1; + int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB + + for (int d = 0; d < num_devs; d++) { + const char * dn = ibv_get_device_name(devs[d]); + if (dev_env && strcmp(dev_env, dn) != 0) continue; + + ibv_context * ctx = ibv_open_device(devs[d]); + if (!ctx) continue; + + ibv_port_attr pa; + if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } + + int found_gid = gid_idx; + int found_version = IBV_GID_TYPE_IB; + if (found_gid < 0) { + // Find a GID on this port whose bytes equal the local TCP address + // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 + // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths + // are avoided. ibv_query_gid_ex returns gid+type in one call. + int v2_idx = -1; + int v1_idx = -1; + for (int i = 0; i < pa.gid_tbl_len; i++) { + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; + if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; + if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { + v2_idx = i; + } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { + v1_idx = i; + } + } + if (v2_idx >= 0) { + found_gid = v2_idx; + found_version = IBV_GID_TYPE_ROCE_V2; + } else if (v1_idx >= 0) { + found_gid = v1_idx; + found_version = IBV_GID_TYPE_ROCE_V1; + } + } else { + // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { + found_version = entry.gid_type; + } + } + if (found_gid >= 0) { + ibctx = ctx; + gid_idx = found_gid; + gid_version = found_version; + matched_dev = dn; + out->path_mtu = pa.active_mtu; + break; + } + ibv_close_device(ctx); + } + ibv_free_device_list(devs); + if (!ibctx) return nullptr; + + out->ib_port = ib_port; + out->gid_idx = gid_idx; + + // unique_ptr owns ibctx and every subsequent resource via ~rdma_conn(), + // so each failure path is a plain `return nullptr;`. + auto c = std::make_unique(); + c->ctx = ibctx; + + c->pd = ibv_alloc_pd(ibctx); + if (!c->pd) return nullptr; + + c->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); + c->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); + if (!c->scq || !c->rcq) return nullptr; + + ibv_qp_init_attr qia = {}; + qia.send_cq = c->scq; + qia.recv_cq = c->rcq; + qia.qp_type = IBV_QPT_RC; + qia.cap.max_send_wr = 4; + qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; + qia.cap.max_send_sge = 1; + qia.cap.max_recv_sge = 1; + qia.cap.max_inline_data = 256; + + c->qp = ibv_create_qp(c->pd, &qia); + if (!c->qp) return nullptr; + c->max_inline = qia.cap.max_inline_data; + + c->tx_buf = aligned_alloc(4096, RDMA_CHUNK); + c->rx_buf = aligned_alloc(4096, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK); + if (!c->tx_buf || !c->rx_buf) return nullptr; + + c->tx_mr = ibv_reg_mr(c->pd, c->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); + c->rx_mr = ibv_reg_mr(c->pd, c->rx_buf, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!c->tx_mr || !c->rx_mr) return nullptr; + + ibv_gid local_gid; + if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return nullptr; + + out->qpn = c->qp->qp_num; + out->psn = c->qp->qp_num & 0xffffff; + memcpy(out->gid, &local_gid, RDMA_GID_SIZE); + + const char * ver_str = ""; + if (gid_version == IBV_GID_TYPE_ROCE_V2) { + ver_str = " RoCEv2"; + } else if (gid_version == IBV_GID_TYPE_ROCE_V1) { + ver_str = " RoCEv1"; + } + GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", + matched_dev, gid_idx, ver_str, out->qpn, c->max_inline); + return c.release(); +} + +// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. +// On success, the connection is live and ready for rdma_send/rdma_recv. +static bool rdma_activate(rdma_conn * c, const rdma_local_info * local, + uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { + // RESET -> INIT + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_INIT; + a.port_num = local->ib_port; + a.pkey_index = 0; + a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; + if (ibv_modify_qp(c->qp, &a, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + return false; + } + } + + for (int i = 0; i < RDMA_RX_DEPTH; i++) { + if (!c->post_rx(i)) return false; + } + + // INIT -> RTR + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTR; + a.path_mtu = local->path_mtu; + a.dest_qp_num = remote_qpn; + a.rq_psn = remote_psn; + a.max_dest_rd_atomic = 1; + a.min_rnr_timer = 1; + a.ah_attr.is_global = 1; + memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); + a.ah_attr.grh.hop_limit = 1; + a.ah_attr.grh.sgid_index = local->gid_idx; + a.ah_attr.dlid = 0; + a.ah_attr.port_num = local->ib_port; + if (ibv_modify_qp(c->qp, &a, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { + return false; + } + } + + // RTR -> RTS + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTS; + a.timeout = 14; + a.retry_cnt = 7; + a.rnr_retry = 7; + a.sq_psn = local->psn; + a.max_rd_atomic = 1; + if (ibv_modify_qp(c->qp, &a, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { + return false; + } + } + + GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n", + local->qpn, remote_qpn, 128 << local->path_mtu, RDMA_RX_DEPTH); + return true; +} + +#endif // GGML_RPC_RDMA + +// --------------------------------------------------------------------------- +// socket_t transport capability methods +// --------------------------------------------------------------------------- + +void socket_t::get_caps(uint8_t * caps) { + memset(caps, 0, RPC_CONN_CAPS_SIZE); +#ifdef GGML_RPC_RDMA + rdma_local = {}; + rdma.reset(rdma_probe(fd, &rdma_local)); + if (rdma) { + rdma_caps rc = {}; + rc.qpn = rdma_local.qpn; + rc.psn = rdma_local.psn; + memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); + memcpy(caps, &rc, sizeof(rc)); + } +#endif // GGML_RPC_RDMA +} + +void socket_t::update_caps(const uint8_t * remote_caps) { +#ifdef GGML_RPC_RDMA + if (!rdma) { + return; + } + rdma_caps rc = {}; + memcpy(&rc, remote_caps, sizeof(rc)); + if (rc.qpn == 0) { + rdma.reset(); + return; + } + if (rdma_activate(rdma.get(), &rdma_local, rc.qpn, rc.psn, rc.gid)) { + fn_send = rdma_send_impl; + fn_recv = rdma_recv_impl; + } else { + GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); + rdma.reset(); + } +#else + (void)remote_caps; +#endif // GGML_RPC_RDMA +} + +// unified transport dispatch (via function pointers) + +static bool send_data(socket_t * sock, const void * data, size_t size) { + return sock->fn_send(sock, data, size); +} + +static bool recv_data(socket_t * sock, void * data, size_t size) { + return sock->fn_recv(sock, data, size); +} + +static bool send_msg(socket_t * sock, const void * msg, size_t msg_size) { + if (!send_data(sock, &msg_size, sizeof(msg_size))) { return false; } - return send_data(sockfd, msg, msg_size); + return send_data(sock, msg, msg_size); } -static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) { +static bool recv_msg(socket_t * sock, void * msg, size_t msg_size) { uint64_t size; - if (!recv_data(sockfd, &size, sizeof(size))) { + if (!recv_data(sock, &size, sizeof(size))) { return false; } if (size != msg_size) { return false; } - return recv_data(sockfd, msg, msg_size); + return recv_data(sock, msg, msg_size); } -static bool recv_msg(sockfd_t sockfd, std::vector & input) { +static bool recv_msg(socket_t * sock, std::vector & input) { uint64_t size; - if (!recv_data(sockfd, &size, sizeof(size))) { + if (!recv_data(sock, &size, sizeof(size))) { return false; } try { @@ -443,7 +945,7 @@ static bool recv_msg(sockfd_t sockfd, std::vector & input) { GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size); return false; } - return recv_data(sockfd, input.data(), size); + return recv_data(sock, input.data(), size); } static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { @@ -452,7 +954,11 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int return false; } host = endpoint.substr(0, pos); - port = std::stoi(endpoint.substr(pos + 1)); + try { + port = std::stoi(endpoint.substr(pos + 1)); + } catch (...) { + return false; + } return true; } @@ -460,13 +966,13 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int // No response static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size) { uint8_t cmd_byte = cmd; - if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { + if (!send_data(sock.get(), &cmd_byte, sizeof(cmd_byte))) { return false; } - if (!send_data(sock->fd, &input_size, sizeof(input_size))) { + if (!send_data(sock.get(), &input_size, sizeof(input_size))) { return false; } - if (!send_data(sock->fd, input, input_size)) { + if (!send_data(sock.get(), input, input_size)) { return false; } return true; @@ -478,16 +984,14 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm if (!send_rpc_cmd(sock, cmd, input, input_size)) { return false; } - // TODO: currently the output_size is always known, do we need support for commands with variable output size? - // even if we do, we can skip sending output_size from the server for commands with known output size uint64_t out_size; - if (!recv_data(sock->fd, &out_size, sizeof(out_size))) { + if (!recv_data(sock.get(), &out_size, sizeof(out_size))) { return false; } if (out_size != output_size) { return false; } - if (!recv_data(sock->fd, output, output_size)) { + if (!recv_data(sock.get(), output, output_size)) { return false; } return true; @@ -495,17 +999,25 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC client-side implementation -static bool check_server_version(const std::shared_ptr & sock) { - rpc_msg_hello_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response)); +// Performs HELLO handshake with transport auto-negotiation. +// Advertises local capabilities via conn_caps; if the server responds with +// matching capabilities, the socket is upgraded transparently. +static bool negotiate_hello(const std::shared_ptr & sock) { + rpc_msg_hello_req request = {}; + rpc_msg_hello_rsp response = {}; + + sock->get_caps(request.conn_caps); + + bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); + if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) { - GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); + GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", + response.major, response.minor, response.patch); return false; } - if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) { - GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); - } + + sock->update_caps(response.conn_caps); return true; } @@ -527,6 +1039,7 @@ static std::shared_ptr get_socket(const std::string & endpoint) { GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str()); return nullptr; } + #ifdef _WIN32 if (!initialized) { WSADATA wsaData; @@ -543,10 +1056,10 @@ static std::shared_ptr get_socket(const std::string & endpoint) { if (sock == nullptr) { return nullptr; } - if (!check_server_version(sock)) { + if (!negotiate_hello(sock)) { return nullptr; } - LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); + LOG_DBG("[%s] connected to %s\n", __func__, endpoint.c_str()); sockets[endpoint] = sock; return sock; } @@ -1597,25 +2110,46 @@ rpc_server::~rpc_server() { } static void rpc_serve_client(const std::vector & backends, const char * cache_dir, - sockfd_t sockfd) { + socket_t * sockfd) { rpc_server server(backends, cache_dir); uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { return; } - // the first command sent by the client must be HELLO if (cmd != RPC_CMD_HELLO) { GGML_LOG_ERROR("Expected HELLO command, update client\n"); return; } - if (!recv_msg(sockfd, nullptr, 0)) { + + // Read input_size and validate protocol version + uint64_t hello_input_size; + if (!recv_data(sockfd, &hello_input_size, sizeof(hello_input_size))) { return; } - rpc_msg_hello_rsp response; - server.hello(response); - if (!send_msg(sockfd, &response, sizeof(response))) { + + if (hello_input_size != sizeof(rpc_msg_hello_req)) { + GGML_LOG_ERROR("HELLO request size mismatch (%zu vs %zu) — client needs upgrade to protocol v%d.x\n", + (size_t)hello_input_size, sizeof(rpc_msg_hello_req), RPC_PROTO_MAJOR_VERSION); + return; + } + + rpc_msg_hello_req req = {}; + if (!recv_data(sockfd, &req, sizeof(req))) { return; } + + rpc_msg_hello_rsp rsp = {}; + server.hello(rsp); + + // Advertise server transport capabilities based on client's caps + sockfd->get_caps(rsp.conn_caps); + + if (!send_msg(sockfd, &rsp, sizeof(rsp))) { + return; + } + + // Activate transport upgrade using client's caps + sockfd->update_caps(req.conn_caps); while (true) { if (!recv_data(sockfd, &cmd, 1)) { break; @@ -1884,6 +2418,12 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir if (!parse_endpoint(endpoint, host, port)) { return; } + +#ifdef GGML_RPC_RDMA + printf(" transport : TCP (RDMA auto-negotiate enabled)\n"); +#else + printf(" transport : TCP\n"); +#endif // GGML_RPC_RDMA #ifdef _WIN32 { WSADATA wsaData; @@ -1907,7 +2447,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir } printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backends, cache_dir, client_socket->fd); + rpc_serve_client(backends, cache_dir, client_socket.get()); printf("Client connection closed\n"); fflush(stdout); } From 7e57b20d533b2854738e005db9c4c8aa510d67bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 15 Apr 2026 15:58:40 +0200 Subject: [PATCH 138/249] CUDA: manage NCCL communicators in context (llama/21891) * CUDA: manage NCCL communicators in context * add check that all backends are CUDA * remove unused vector, limit init to > 1 GPUs * fix warnings * fix cuda device, cache allreduce --- ggml/include/ggml-backend.h | 7 +- ggml/src/ggml-backend-meta.cpp | 37 +++++++--- ggml/src/ggml-cuda/common.cuh | 4 -- ggml/src/ggml-cuda/ggml-cuda.cu | 118 +++++++++++++++++++++++--------- 4 files changed, 119 insertions(+), 47 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 4a8f6d4287d..d0c7e5a1be0 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -202,8 +202,11 @@ extern "C" { // Common functions that may be obtained using ggml_backend_reg_get_proc_address - // AllReduce operation for tensor parallelism (meta backend) - typedef bool (*ggml_backend_allreduce_tensor_t)(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends); + // Context management and operations for faster communication between backends, used for tensor parallelism (meta backend) + typedef void * (*ggml_backend_comm_init_t)(ggml_backend_t * backends, size_t n_backends); + typedef void (*ggml_backend_comm_free_t)(void * comm_ctx); + typedef bool (*ggml_backend_comm_allreduce_tensor_t)(void * comm_ctx, struct ggml_tensor ** tensors); + // Split buffer type for tensor parallelism (old) typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split); // Set the number of threads for the backend diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 0a8eea4e945..1ee3eeb4d96 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1419,22 +1419,48 @@ struct ggml_backend_meta_context { size_t max_tmp_size = 0; size_t max_subgraphs = 0; + void * comm_ctx = nullptr; + ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr; + ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) { const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev); name = "Meta("; + std::vector simple_backends; backend_configs.reserve(n_devs); + simple_backends.reserve(n_devs); for (size_t i = 0; i < n_devs; i++) { ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i); if (i > 0) { name += ","; } name += ggml_backend_dev_name(simple_dev); - backend_configs.emplace_back(ggml_backend_dev_init(simple_dev, params)); + simple_backends.push_back(ggml_backend_dev_init(simple_dev, params)); + backend_configs.emplace_back(simple_backends.back()); } name += ")"; + + if (n_devs > 1) { + ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t) ggml_backend_reg_get_proc_address( + ggml_backend_dev_backend_reg(ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_init"); + if (comm_init != nullptr) { + comm_ctx = comm_init(simple_backends.data(), simple_backends.size()); + } + } + if (comm_ctx != nullptr) { + comm_allreduce = (ggml_backend_comm_allreduce_tensor_t) + ggml_backend_reg_get_proc_address(ggml_backend_dev_backend_reg( + ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_allreduce_tensor"); + GGML_ASSERT(comm_allreduce != nullptr); + } } ~ggml_backend_meta_context() { + if (comm_ctx != nullptr) { + ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t) ggml_backend_reg_get_proc_address( + ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_configs[0].backend)), "ggml_backend_comm_free"); + GGML_ASSERT(comm_free != nullptr); + comm_free(comm_ctx); + } for (auto & bc : backend_configs) { ggml_backend_free(bc.backend); } @@ -1845,20 +1871,15 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, if (n_backends > 1 && i < n_subgraphs - 1) { bool backend_allreduce_success = false; - ggml_backend_allreduce_tensor_t allreduce_tensor = (ggml_backend_allreduce_tensor_t) ggml_backend_reg_get_proc_address( - ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_ctx->backend_configs[0].backend)), "ggml_backend_allreduce_tensor"); - if (allreduce_tensor) { - std::vector backends; - backends.reserve(n_backends); + if (backend_ctx->comm_ctx) { std::vector nodes; nodes.reserve(n_backends); for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; - backends.push_back(bcj.backend); ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main; nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]); } - backend_allreduce_success = allreduce_tensor(backends.data(), nodes.data(), n_backends); + backend_allreduce_success = backend_ctx->comm_allreduce(backend_ctx->comm_ctx, nodes.data()); } if (!backend_allreduce_success) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8a4246223b5..2e5eaff9bf4 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1092,10 +1092,6 @@ struct ggml_cuda_device_info { cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {}; std::array default_tensor_split = {}; - -#ifdef GGML_USE_NCCL - ncclComm_t comms[GGML_CUDA_MAX_DEVICES]; -#endif // GGML_USE_NCCL }; const ggml_cuda_device_info & ggml_cuda_info(); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3113de017f0..5d81befec32 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -338,14 +338,6 @@ static ggml_cuda_device_info ggml_cuda_init() { } } -#ifdef GGML_USE_NCCL - int dev_ids[GGML_CUDA_MAX_DEVICES]; - for (int id = 0; id < info.device_count; ++id) { - dev_ids[id] = id; - } - NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids)); -#endif // GGML_USE_NCCL - return info; } @@ -1125,7 +1117,69 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, }; -bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends) { +#ifdef GGML_USE_NCCL +struct ggml_backend_cuda_comm_context { + std::vector backends; + std::vector comms; + + ~ggml_backend_cuda_comm_context() { + for (ncclComm_t comm : comms) { + NCCL_CHECK(ncclCommDestroy(comm)); + } + } +}; +#endif // GGML_USE_NCCL + +static void ggml_backend_cuda_comm_free(void * comm_ctx_v) { +#ifdef GGML_USE_NCCL + if (comm_ctx_v == nullptr) { + return; + } + ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v; + delete comm_ctx; +#else + GGML_UNUSED(comm_ctx_v); +#endif // GGML_USE_NCCL +} + +static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) { +#ifdef GGML_USE_NCCL + for (size_t i = 0; i < n_backends; i++) { + if (!ggml_backend_is_cuda(backends[i])) { + return nullptr; + } + } + ggml_backend_cuda_comm_context * ret = new ggml_backend_cuda_comm_context; + std::vector dev_ids; + ret->backends.reserve(n_backends); + dev_ids.reserve(n_backends); + for (size_t i = 0; i < n_backends; i++) { + ret->backends.push_back(backends[i]); + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + dev_ids.push_back(cuda_ctx->device); + } + + ret->comms.resize(n_backends); + NCCL_CHECK(ncclCommInitAll(ret->comms.data(), n_backends, dev_ids.data())); + return ret; +#else + // If NCCL is installed it is used by default for optimal performance. + // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package. + // RCCL is disabled by default, users are explicitly opting in. + // Therefore print no warning for RCCL. +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + static bool warning_printed = false; + if (!warning_printed) { + GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__); + warning_printed = true; + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + GGML_UNUSED_VARS(backends, n_backends); + return nullptr; +#endif // GGML_USE_NCCL +} + +static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) { #ifdef GGML_USE_NCCL const int64_t ne = ggml_nelements(tensors[0]); // FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0 @@ -1133,21 +1187,24 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t if (ne == 0) { return true; } + + GGML_ASSERT(comm_ctx_v != nullptr); + ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v; + const size_t n_backends = comm_ctx->backends.size(); + for (size_t i = 0; i < n_backends; ++i) { GGML_ASSERT(tensors[i] != nullptr); GGML_ASSERT(ggml_nelements(tensors[i]) == ne); GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i])); } - const ggml_cuda_device_info info = ggml_cuda_info(); - // For small tensors, simply reduce them as FP32. // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0. if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) { NCCL_CHECK(ncclGroupStart()); for (size_t i = 0; i < n_backends; ++i) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; - NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream())); + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comm_ctx->comms[i], cuda_ctx->stream())); } NCCL_CHECK(ncclGroupEnd()); @@ -1160,44 +1217,33 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t ggml_cuda_pool_alloc tmp[GGML_CUDA_MAX_DEVICES]; for (size_t i = 0; i < n_backends; ++i) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; tmp[i].pool = &cuda_ctx->pool(); tmp[i].alloc(ne); - ggml_cuda_set_device(i); + ggml_cuda_set_device(cuda_ctx->device); to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream()); CUDA_CHECK(cudaGetLastError()); } NCCL_CHECK(ncclGroupStart()); for (size_t i = 0; i < n_backends; ++i) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; - NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream())); + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, comm_ctx->comms[i], cuda_ctx->stream())); } NCCL_CHECK(ncclGroupEnd()); for (size_t i = 0; i < n_backends; ++i) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; - ggml_cuda_set_device(i); + ggml_cuda_set_device(cuda_ctx->device); to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream()); CUDA_CHECK(cudaGetLastError()); } return true; #else - // If NCCL is installed it is used by default for optimal performance. - // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package. - // RCCL is disabled by default, users are explicitly opting in. - // Therefore print no warning for RCCL. -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - static bool warning_printed = false; - if (!warning_printed) { - GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__); - warning_printed = true; - } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - GGML_UNUSED_VARS(backends, tensors, n_backends); + GGML_UNUSED_VARS(comm_ctx_v, tensors); return false; #endif // GGML_USE_NCCL } @@ -5220,8 +5266,14 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { GGML_UNUSED(reg); - if (strcmp(name, "ggml_backend_allreduce_tensor") == 0) { - return (void *)ggml_backend_cuda_allreduce_tensor; + if (strcmp(name, "ggml_backend_comm_init") == 0) { + return (void *)ggml_backend_cuda_comm_init; + } + if (strcmp(name, "ggml_backend_comm_free") == 0) { + return (void *)ggml_backend_cuda_comm_free; + } + if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) { + return (void *)ggml_backend_cuda_comm_allreduce_tensor; } if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { return (void *)ggml_backend_cuda_split_buffer_type; From 9638e29657e7c547212284a2b473335c31a86a05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 15 Apr 2026 16:01:46 +0200 Subject: [PATCH 139/249] CUDA: require explicit opt-in for P2P access (llama/21910) --- ggml/src/ggml-cuda/ggml-cuda.cu | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5d81befec32..c17db3875ad 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -324,16 +324,18 @@ static ggml_cuda_device_info ggml_cuda_init() { // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); - for (int id = 0; id < info.device_count; ++id) { - ggml_cuda_set_device(id); - for (int id_other = 0; id_other < info.device_count; ++id_other) { - if (id == id_other) { - continue; - } - int can_access_peer; - CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); - if (can_access_peer) { - CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); + if (getenv("GGML_CUDA_P2P") != nullptr) { + for (int id = 0; id < info.device_count; ++id) { + ggml_cuda_set_device(id); + for (int id_other = 0; id_other < info.device_count; ++id_other) { + if (id == id_other) { + continue; + } + int can_access_peer; + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); + if (can_access_peer) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); + } } } } From 2a785c596944da4cc67d15c8600f606b5021e7bb Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 15 Apr 2026 09:14:40 -0700 Subject: [PATCH 140/249] ggml-webgpu: Fix dequantization helpers to not pass in pointers (llama/21872) * Fix dequantization helpers to not pass in pointers * Increase XIELU precision --- .../wgsl-shaders/common_decls.tmpl | 73 +++++++++----- .../ggml-webgpu/wgsl-shaders/get_rows.wgsl | 90 +++++++++--------- .../src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl | 90 +++++++++--------- .../wgsl-shaders/mul_mat_decls.tmpl | 94 +++++++++---------- .../ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl | 2 + .../wgsl-shaders/mul_mat_reg_tile.wgsl | 2 + .../wgsl-shaders/mul_mat_subgroup_matrix.wgsl | 3 +- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 48 +++++----- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 13 +-- 9 files changed, 223 insertions(+), 192 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 0d3501c34a2..62fe72ee3b1 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -9,42 +9,65 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { #endif #ifdef U32_DEQUANT_HELPERS -fn load_u16_at( - buf: ptr, read_write>, - byte_offset: u32) -> u32 { - let word = buf[byte_offset / 4]; - let shift = (byte_offset & 0x2) * 8; - return (word >> shift) & 0xFFFF; +#ifdef DECLARE_BYTE_LOADERS_SRC +fn load_u16_at_src(byte_offset: u32) -> u32 { + let word = src[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + return (word >> shift) & 0xFFFFu; } -fn load_u32_at( - buf: ptr, read_write>, - byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4; - let shift = (byte_offset & 0x3) * 8; - let lo = buf[word_idx]; - let hi = buf[word_idx + 1]; - let shifted = (lo >> shift) | (hi << (32 - shift)); - return select(shifted, lo, shift == 0); +fn load_u32_at_src(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 0x3u) * 8u; + let lo = src[word_idx]; + let hi = src[word_idx + 1u]; + let shifted = (lo >> shift) | (hi << (32u - shift)); + return select(shifted, lo, shift == 0u); } -fn load_f16_at( - buf: ptr, read_write>, - byte_offset: u32) -> f16 { - let packed = unpack2x16float(load_u16_at(buf, byte_offset)); +fn load_f16_at_src(byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_u16_at_src(byte_offset)); return f16(packed[0]); } -fn load_f16_as_f32_at( - buf: ptr, read_write>, - byte_offset: u32) -> f32 { - let word = buf[byte_offset / 4]; - let shift = (byte_offset & 0x2) * 8; - let d_bits = (word >> shift) & 0xFFFF; +fn load_f16_as_f32_at_src(byte_offset: u32) -> f32 { + let word = src[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + let d_bits = (word >> shift) & 0xFFFFu; return unpack2x16float(d_bits)[0]; } #endif +#ifdef DECLARE_BYTE_LOADERS_SRC0 +fn load_u16_at_src0(byte_offset: u32) -> u32 { + let word = src0[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_u32_at_src0(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 0x3u) * 8u; + let lo = src0[word_idx]; + let hi = src0[word_idx + 1u]; + let shifted = (lo >> shift) | (hi << (32u - shift)); + return select(shifted, lo, shift == 0u); +} + +fn load_f16_at_src0(byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_u16_at_src0(byte_offset)); + return f16(packed[0]); +} + +fn load_f16_as_f32_at_src0(byte_offset: u32) -> f32 { + let word = src0[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + let d_bits = (word >> shift) & 0xFFFFu; + return unpack2x16float(d_bits)[0]; +} +#endif +#endif + #ifdef Q4_1_T diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index 3c8b84c9ac3..1415798fa6b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -1,6 +1,8 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC #include "common_decls.tmpl" + #ifdef F32_VEC fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset]; @@ -28,10 +30,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q4_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); for (var j: u32 = 0u; j < 4; j++) { let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at(&src, q_byte_offset); + let q_packed = load_u32_at_src(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; @@ -66,11 +68,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q5_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); - let qh_packed = load_u32_at(&src, block_byte_base + 2); + let d = load_f16_as_f32_at_src(block_byte_base); + let qh_packed = load_u32_at_src(block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { let q_byte_offset = block_byte_base + 6 + j * 4; - let q_packed = load_u32_at(&src, q_byte_offset); + let q_packed = load_u32_at_src(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); @@ -113,10 +115,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q8_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); for (var j: u32 = 0u; j < 8u; j++) { let q_byte_offset = block_byte_base + 2u + j * 4u; - let q_packed = load_u32_at(&src, q_byte_offset); + let q_packed = load_u32_at_src(q_byte_offset); for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -162,16 +164,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes // Bytes 108-109: f16 scale 'd' - let d = load_f16_as_f32_at(&src, block_byte_base + 108); + let d = load_f16_as_f32_at_src(block_byte_base + 108); // Bytes 96-107: 12 bytes of scales (3 u32s) let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; var scale_vals: array; - scale_vals[0] = load_u32_at(&src, block_byte_base + 96); - scale_vals[1] = load_u32_at(&src, block_byte_base + 100); - scale_vals[2] = load_u32_at(&src, block_byte_base + 104); + scale_vals[0] = load_u32_at_src(block_byte_base + 96); + scale_vals[1] = load_u32_at_src(block_byte_base + 100); + scale_vals[2] = load_u32_at_src(block_byte_base + 104); var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); @@ -182,13 +184,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4); + hmask_vals[i] = load_u32_at_src(block_byte_base + i * 4); } // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array; for (var i: u32 = 0u; i < 16; i++) { - qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4); + qs_vals[i] = load_u32_at_src(block_byte_base + 32 + i * 4); } var dst_i = dst_base + offset * 256; @@ -286,24 +288,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes // Bytes 208-209: f16 scale 'd' - let d = load_f16_as_f32_at(&src, block_byte_base + 208); + let d = load_f16_as_f32_at_src(block_byte_base + 208); // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4); + ql_vals[i] = load_u32_at_src(block_byte_base + i * 4); } // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array; for (var i: u32 = 0; i < 16u; i++) { - qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u); + qh_vals[i] = load_u32_at_src(block_byte_base + 128 + i * 4u); } // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4); + scale_vals[i] = load_u32_at_src(block_byte_base + 192 + i * 4); } var dst_i = dst_base + offset * 256; @@ -345,13 +347,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 32; ib += 4) { let aux0_offset = block_byte_base + 2 + ib * 2; let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; - let aux0 = load_u32_at(&src, aux0_offset); - let aux1 = load_u32_at(&src, aux1_offset); + let aux0 = load_u32_at_src(aux0_offset); + let aux1 = load_u32_at_src(aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -373,12 +375,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; var scale_vals = array( - load_u32_at(&src, block_byte_base + 66), - load_u32_at(&src, block_byte_base + 70) + load_u32_at_src(block_byte_base + 66), + load_u32_at_src(block_byte_base + 70) ); for (var ib: u32 = 0; ib < 32; ib += 4) { @@ -389,7 +391,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { ); for (var l: u32 = 0; l < 4; l++) { let qs_offset = block_byte_base + 2 + (ib + l) * 2; - let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF; + let qs_val = load_u32_at_src(qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -408,21 +410,21 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; var qs_vals : array; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4); + qs_vals[i] = load_u32_at_src(block_byte_base + 2 + i * 4); } var qh_vals: array; - qh_vals[0] = load_u32_at(&src, block_byte_base + 66); - qh_vals[1] = load_u32_at(&src, block_byte_base + 70); + qh_vals[0] = load_u32_at_src(block_byte_base + 66); + qh_vals[1] = load_u32_at_src(block_byte_base + 70); var scale_vals: array; - scale_vals[0] = load_u32_at(&src, block_byte_base + 74); - scale_vals[1] = load_u32_at(&src, block_byte_base + 78); + scale_vals[0] = load_u32_at_src(block_byte_base + 74); + scale_vals[1] = load_u32_at_src(block_byte_base + 78); for (var ib: u32 = 0; ib < 8; ib ++) { let s = get_byte(scale_vals[ib / 4], ib % 4); @@ -450,16 +452,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ3_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 16; ib += 2) { let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; - let sc_sign = load_u32_at(&src, sc_sign_offset); + let sc_sign = load_u32_at_src(sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; + let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -480,20 +482,20 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ3_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; var qh_vals = array( - load_u32_at(&src, block_byte_base + 66), - load_u32_at(&src, block_byte_base + 70) + load_u32_at_src(block_byte_base + 66), + load_u32_at_src(block_byte_base + 70) ); var sign_vals: array; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4); + sign_vals[i] = load_u32_at_src(block_byte_base + 74 + i * 4); } - var scale_vals = load_u32_at(&src, block_byte_base + 106); + var scale_vals = load_u32_at_src(block_byte_base + 106); for (var ib: u32 = 0; ib < 4; ib++) { let s = get_byte(scale_vals, ib); @@ -507,7 +509,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; + let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -529,13 +531,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ1_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF; + let qh = load_u32_at_src(block_byte_base + 34 + ib * 2) & 0xFFFF; let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4); + let qs_w = load_u32_at_src(block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -596,11 +598,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ4_NL fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 32; var qs: array; for (var i: u32 = 0; i < 4; i++) { - qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4); + qs[i] = load_u32_at_src(block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index fdabaf09b2e..fcbefdeb802 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -1,7 +1,9 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #ifdef FLOAT const BLOCK_SIZE = 1u; @@ -21,11 +23,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q4_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var sum: f32 = 0.0; for (var j: u32 = 0; j < 4; j++) { let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; @@ -63,12 +65,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q5_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var sum: f32 = 0.0; - let qh_packed = load_u32_at(&src0, block_byte_base + 2); + let qh_packed = load_u32_at_src0(block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { let q_byte_offset = block_byte_base + 6 + j * 4; - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; @@ -110,11 +112,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q8_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var sum: f32 = 0.0; for (var j: u32 = 0; j < 8; j++) { let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -184,7 +186,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes // Bytes 108-109: f16 scale 'd' - let d = load_f16_as_f32_at(&src0, block_byte_base + 108); + let d = load_f16_as_f32_at_src0(block_byte_base + 108); // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, // and 2-bits from the last 4 bytes @@ -192,9 +194,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; var scale_vals: array; - scale_vals[0] = load_u32_at(&src0, block_byte_base + 96); - scale_vals[1] = load_u32_at(&src0, block_byte_base + 100); - scale_vals[2] = load_u32_at(&src0, block_byte_base + 104); + scale_vals[0] = load_u32_at_src0(block_byte_base + 96); + scale_vals[1] = load_u32_at_src0(block_byte_base + 100); + scale_vals[2] = load_u32_at_src0(block_byte_base + 104); var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); @@ -205,13 +207,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4); + hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4); } // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array; for (var i: u32 = 0u; i < 16; i++) { - qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4); + qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4); } var sum = 0.0; @@ -313,24 +315,24 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes // Bytes 208-209: f16 scale 'd' - let d = load_f16_as_f32_at(&src0, block_byte_base + 208); + let d = load_f16_as_f32_at_src0(block_byte_base + 208); // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4); + ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4); } // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array; for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4); + qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4); } // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4); + scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4); } var sum = 0.0; @@ -374,14 +376,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 32; ib += 4) { let aux0_offset = block_byte_base + 2 + ib * 2; let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; - let aux0 = load_u32_at(&src0, aux0_offset); - let aux1 = load_u32_at(&src0, aux1_offset); + let aux0 = load_u32_at_src0(aux0_offset); + let aux1 = load_u32_at_src0(aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -402,12 +404,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var scale_vals = array( - load_u32_at(&src0, block_byte_base + 66), - load_u32_at(&src0, block_byte_base + 70) + load_u32_at_src0(block_byte_base + 66), + load_u32_at_src0(block_byte_base + 70) ); var sum = 0.0; @@ -419,7 +421,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { ); for (var l: u32 = 0; l < 4; l++) { let qs_offset = block_byte_base + 2 + (ib + l) * 2; - let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF; + let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -439,21 +441,21 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var qs_vals : array; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4); + qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); } var qh_vals: array; - qh_vals[0] = load_u32_at(&src0, block_byte_base + 66); - qh_vals[1] = load_u32_at(&src0, block_byte_base + 70); + qh_vals[0] = load_u32_at_src0(block_byte_base + 66); + qh_vals[1] = load_u32_at_src0(block_byte_base + 70); var scale_vals: array; - scale_vals[0] = load_u32_at(&src0, block_byte_base + 74); - scale_vals[1] = load_u32_at(&src0, block_byte_base + 78); + scale_vals[0] = load_u32_at_src0(block_byte_base + 74); + scale_vals[1] = load_u32_at_src0(block_byte_base + 78); var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib ++) { @@ -483,17 +485,17 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ3_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 16; ib += 2) { let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; - let sc_sign = load_u32_at(&src0, sc_sign_offset); + let sc_sign = load_u32_at_src0(sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; + let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -515,20 +517,20 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ3_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var qh_vals = array( - load_u32_at(&src0, block_byte_base + 66), - load_u32_at(&src0, block_byte_base + 70) + load_u32_at_src0(block_byte_base + 66), + load_u32_at_src0(block_byte_base + 70) ); var sign_vals: array; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4); + sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4); } - var scale_vals = load_u32_at(&src0, block_byte_base + 106); + var scale_vals = load_u32_at_src0(block_byte_base + 106); var sum = 0.0; for (var ib: u32 = 0; ib < 4; ib++) { @@ -543,7 +545,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; + let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -566,14 +568,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ1_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF; + let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF; let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4); + let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -638,12 +640,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ4_NL fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 32; var sum = 0.0; var qs: array; for (var i: u32 = 0; i < 4; i++) { - qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4); + qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 56a76a6e6c4..5a323818260 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); + let d = load_f16_at_src0(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let m = load_f16_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let m = load_f16_at_src0(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_lo = f16(q_byte & 0xF) * d + m; @@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let qh_packed = load_u32_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let m = load_f16_at(&src0, block_byte_base + 2u); - let qh_packed = load_u32_at(&src0, block_byte_base + 4u); + let d = load_f16_at_src0(block_byte_base); + let m = load_f16_at_src0(block_byte_base + 2u); + let qh_packed = load_u32_at_src0(block_byte_base + 4u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); + let d = load_f16_at_src0(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j+=2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let m = load_f16_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let m = load_f16_at_src0(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j+=2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base + 80u); - let dmin = load_f16_at(&src0, block_byte_base + 82u); + let d = load_f16_at_src0(block_byte_base + 80u); + let dmin = load_f16_at_src0(block_byte_base + 82u); // Decode the element at position k_in_block let block_of_32 = k_in_block / 32u; @@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let is = k_in_block / 16u; - let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u)); + let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u)); let sc = get_byte(sc_packed, is % 4u); let dl = d * f16(sc & 0xFu); let ml = dmin * f16(sc >> 4u); let q_idx = q_b_idx + k + l; - let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 3u; @@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base + 108u); + let d = load_f16_at_src0(block_byte_base + 108u); // Load and unpack scales let kmask1: u32 = 0x03030303u; @@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var scale_vals: array; for (var i: u32 = 0u; i < 4u; i++) { - scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i); + scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i); } var tmp: u32 = scale_vals[2]; @@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load hmask and qs arrays var hmask_vals: array; for (var i: u32 = 0u; i < 8u; i++) { - hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i); + hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i); } var qs_vals: array; for (var i: u32 = 0u; i < 16u; i++) { - qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i); + qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i); } let half = k_in_block / 128u; // 0 or 1 @@ -499,8 +499,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let dmin = load_f16_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let dmin = load_f16_at_src0(block_byte_base + 2u); // Map k_in_block to loop structure: // Outer loop over 64-element groups (alternating q_b_idx) @@ -520,14 +520,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let scale_base = block_byte_base + 4u; if (is < 4u) { - let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u); - let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); + let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); + let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = sc_byte & 63u; mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); + let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); @@ -537,7 +537,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 0xFu; @@ -571,8 +571,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let dmin = load_f16_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let dmin = load_f16_at_src0(block_byte_base + 2u); // The original loop processes elements in groups of 64 @@ -597,14 +597,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let scale_base = block_byte_base + 4u; if (is < 4u) { - let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u); - let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); + let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); + let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = sc_byte & 63u; mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); + let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); @@ -614,11 +614,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); - let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u)); + let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u)); let qh_byte = get_byte(qh_packed, l % 4u); @@ -666,17 +666,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only ql13 word needed let ql13_flat = ql_b_idx + l; - let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat); + let ql13 = load_u32_at_src0(block_byte_base + ql13_flat); let ql13_b = get_byte(ql13, 0u); // Load only ql24 word needed let ql24_flat = ql_b_idx + l + 32u; - let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat); + let ql24 = load_u32_at_src0(block_byte_base + ql24_flat); let ql24_b = get_byte(ql24, 0u); // Load only qh word needed let qh_flat = qh_b_idx + l; - let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat); + let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat); let qh_b = get_byte(qh, 0u); let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); @@ -687,10 +687,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only the scale word needed let is = l / 16u; let sc_idx = sc_b_idx + is + quarter * 2u; - let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx); + let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx); let sc_val = get_byte_i32(sc, 0u); - let d = load_f16_at(&src0, block_byte_base + 208u); + let d = load_f16_at_src0(block_byte_base + 208u); var q_val: f16; if (quarter == 0u) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl index 5f763a6400a..91039ff2546 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl @@ -1,6 +1,8 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #include "mul_mat_decls.tmpl" #ifdef VEC diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index ee37e6d249c..98bbdeb83ba 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -1,6 +1,8 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #include "mul_mat_decls.tmpl" #ifdef VEC diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 4151ce430b0..d86a72ce6e0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -3,7 +3,9 @@ enable f16; enable subgroups; enable chromium_experimental_subgroup_matrix; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #include "mul_mat_decls.tmpl" // TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs. @@ -196,4 +198,3 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } } - diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 6f6bcaf7940..9f7b3e32eca 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -1,7 +1,9 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #ifdef VEC #define VEC_SIZE 4 @@ -65,10 +67,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at(&src0, block_byte_base)); + let d = f32(load_f16_at_src0(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; @@ -98,11 +100,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at(&src0, block_byte_base)); - let m = f32(load_f16_at(&src0, block_byte_base + 2u)); + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = f32((q_byte >> 4) & 0xF) * d + m; @@ -132,12 +134,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at(&src0, block_byte_base)); - let qh_packed = load_u32_at(&src0, block_byte_base + 2u); + let d = f32(load_f16_at_src0(block_byte_base)); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -176,13 +178,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at(&src0, block_byte_base)); - let m = load_f16_at(&src0, block_byte_base + 2u); - let qh_packed = load_u32_at(&src0, block_byte_base + 4u); + let d = f32(load_f16_at_src0(block_byte_base)); + let m = load_f16_at_src0(block_byte_base + 2u); + let qh_packed = load_u32_at_src0(block_byte_base + 4u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -221,11 +223,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at(&src0, block_byte_base)); + let d = f32(load_f16_at_src0(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -254,12 +256,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at(&src0, block_byte_base)); - let m = load_f16_at(&src0, block_byte_base + 2u); + let d = f32(load_f16_at_src0(block_byte_base)); + let m = load_f16_at_src0(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d + f32(m); @@ -309,13 +311,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = ix; i < nb; i += 2u) { let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at(&src0, bbase + 208u)); + let d = f32(load_f16_at_src0(bbase + 208u)); - let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l); - let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u); - let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h); - let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte); - let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u); + let ql1_u32 = load_u32_at_src0(bbase + q_offset_l); + let ql2_u32 = load_u32_at_src0(bbase + q_offset_l + 32u); + let qh_u32 = load_u32_at_src0(bbase + 128u + q_offset_h); + let sc_u32_0 = load_u32_at_src0(bbase + sc_base_byte); + let sc_u32_1 = load_u32_at_src0(bbase + sc_base_byte + 4u); let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index 8c334817ccd..b8f1bca1284 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -147,15 +147,12 @@ fn main(@builtin(global_invocation_id) gid: vec3) { -9.010913, 9.010913))); #endif #ifdef XIELU + let val = f32(src[params.offset_src + src_idx]); let res = - select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) - - src[params.offset_src + src_idx]) * - TYPE(params.alpha_n) + - TYPE(params.beta) * src[params.offset_src + src_idx], - TYPE(params.alpha_p) * src[params.offset_src + src_idx] * - src[params.offset_src + src_idx] + - TYPE(params.beta) * src[params.offset_src + src_idx], - src[params.offset_src + src_idx] > 0.0); + TYPE(select( + ((exp(min(val, params.eps)) - 1.0) - val) * params.alpha_n + params.beta * val, + params.alpha_p * val * val + params.beta * val, + val > 0.0)); #endif #ifdef SOFTPLUS let src_f32 = f32(src[params.offset_src + src_idx]); From c6d1fbf31f3f8c611772e3a6bb3d3b35ac5f01eb Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Wed, 15 Apr 2026 09:38:38 -0700 Subject: [PATCH 141/249] cuda: Q1_0 initial backend (llama/21629) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [cuda] initial Q1_0 backend * remove unused code, fix AMD MMA guard * attempt to support dp4a * Apply suggestions from code review Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 7 ++ ggml/src/ggml-cuda/convert.cu | 10 ++ ggml/src/ggml-cuda/dequantize.cuh | 22 +++++ ggml/src/ggml-cuda/getrows.cu | 4 + ggml/src/ggml-cuda/ggml-cuda.cu | 2 + ggml/src/ggml-cuda/mmq.cu | 4 + ggml/src/ggml-cuda/mmq.cuh | 93 +++++++++++++++++++ ggml/src/ggml-cuda/mmvq.cu | 8 ++ .../template-instances/generate_cu_files.py | 1 + .../template-instances/mmq-instance-q1_0.cu | 5 + ggml/src/ggml-cuda/vecdotq.cuh | 48 ++++++++++ 11 files changed, 204 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2e5eaff9bf4..ad30ecd8fa5 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -924,6 +924,13 @@ struct ggml_cuda_type_traits { static constexpr int qr = 1; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK1_0; + static constexpr int qr = QR1_0; + static constexpr int qi = QI1_0; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK4_0; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 79ccfe568a2..61630a35a29 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -711,6 +711,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: + return dequantize_block_cont_cuda; case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; case GGML_TYPE_Q4_1: @@ -767,6 +769,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: + return dequantize_block_cont_cuda; case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; case GGML_TYPE_Q4_1: @@ -822,6 +826,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda; case GGML_TYPE_Q4_0: return dequantize_block_cuda; case GGML_TYPE_Q4_1: @@ -843,6 +849,8 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda; case GGML_TYPE_Q4_0: return dequantize_block_cuda; case GGML_TYPE_Q4_1: @@ -864,6 +872,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F16: return convert_unary_cuda; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda; case GGML_TYPE_Q4_0: return dequantize_block_cuda; case GGML_TYPE_Q4_1: diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh index e060fb29fdc..9ae1342fc0e 100644 --- a/ggml/src/ggml-cuda/dequantize.cuh +++ b/ggml/src/ggml-cuda/dequantize.cuh @@ -1,5 +1,27 @@ #include "common.cuh" +static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_q1_0 * x = (const block_q1_0 *) vx; + + const float d = x[ib].d; + + const int bit_index_0 = iqs; + const int bit_index_1 = iqs + 1; + + const int byte_index_0 = bit_index_0 / 8; + const int bit_offset_0 = bit_index_0 % 8; + + const int byte_index_1 = bit_index_1 / 8; + const int bit_offset_1 = bit_index_1 % 8; + + // Extract bits: 1 = +d, 0 = -d (branchless) + const int bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1; + const int bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1; + + v.x = (2*bit_0 - 1) * d; + v.y = (2*bit_1 - 1) * d; +} + static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 2fab33243dd..e99cba63d34 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -179,6 +179,10 @@ static void ggml_cuda_get_rows_switch_src0_type( get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_Q1_0: + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_Q4_0: get_rows_cuda_q(src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c17db3875ad..790f53cead7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4831,6 +4831,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g switch (a->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4868,6 +4869,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_F32: case GGML_TYPE_BF16: case GGML_TYPE_I32: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 27b4145ac9a..3f01ff5bfb0 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -5,6 +5,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { + case GGML_TYPE_Q1_0: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_Q4_0: mul_mat_q_case(ctx, args, stream); break; @@ -270,6 +273,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t bool mmq_supported; switch (type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 18911141472..28b662df925 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -57,6 +57,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { switch (type_x) { + case GGML_TYPE_Q1_0: + return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: return MMQ_Q8_1_DS_LAYOUT_DS4; @@ -185,6 +187,7 @@ static constexpr __device__ int get_mmq_y_device() { static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { switch (type) { + case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; @@ -229,6 +232,7 @@ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; @@ -302,6 +306,87 @@ static constexpr __device__ int mmq_get_nwarps_device() { // ------------------------------------------------------------ +template static __device__ __forceinline__ void load_tiles_q1_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0; + constexpr int threads_per_row = blocks_per_iter * QI1_0; + constexpr int nrows = warp_size / threads_per_row; + constexpr int scale_entries_per_block = QK1_0 / QK8_1; + constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block; + + const int txi = threadIdx.x % threads_per_row; + const int kbx = txi / QI1_0; + const int kqsx = txi % QI1_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx; + const int qs_offset = 4*kqsx; + const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) | + (bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24); + + int unpacked_bytes[8]; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int shift = j * 4; + const int bits4 = (qs0 >> shift) & 0x0F; + const int b0 = (bits4 & 0x01) ? 1 : -1; + const int b1 = (bits4 & 0x02) ? 1 : -1; + const int b2 = (bits4 & 0x04) ? 1 : -1; + const int b3 = (bits4 & 0x08) ? 1 : -1; + unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24); + } + + const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0; +#pragma unroll + for (int j = 0; j < 8; ++j) { +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j]; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j]; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + } + + const int ksx = threadIdx.x % scale_entries_per_row; + const int scale_block = ksx / scale_entries_per_block; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d; +#else + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + template static __device__ __forceinline__ void load_tiles_q4_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); @@ -3290,6 +3375,14 @@ static __device__ __forceinline__ void mmq_write_back_mma( template struct mmq_type_traits; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 07b10167bc4..8f55cace1a1 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -9,6 +9,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1; case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1; case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1; case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; @@ -36,6 +37,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ; case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; @@ -886,6 +888,12 @@ static void mul_mat_vec_q_switch_type( const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const int ids_stride, cudaStream_t stream) { switch (type_x) { + case GGML_TYPE_Q1_0: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; case GGML_TYPE_Q4_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 40d51f93fa4..841059c15b5 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -32,6 +32,7 @@ SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" TYPES_MMQ = [ + "GGML_TYPE_Q1_0", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu new file mode 100644 index 00000000000..f0686b0d0d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q1_0); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 40b2b41e7e8..d1741cc8d7b 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -106,6 +106,9 @@ static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) { // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q +#define VDR_Q1_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism +#define VDR_Q1_0_Q8_1_MMQ 4 // Q1_0 has 128 bits (4 ints) per block + #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 @@ -669,6 +672,51 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( return d6 * sumf_d; } +static __device__ __forceinline__ float vec_dot_q1_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q1_0 * bq1_0 = (const block_q1_0 *) vbq + kbx; + + // Q1_0: 128 elements with ONE scale + // Q8_1: 32 elements per block with individual scales + // iqs selects which of the 4 chunks of 32 elements to process (0-3) + + const float d1 = bq1_0->d; + + // Process only the chunk specified by iqs + const block_q8_1 * bq8_1_chunk = bq8_1 + iqs; + + // Load 32 bits (4 bytes) for this chunk from Q1_0 + const int offset = iqs * 4; + const int v = bq1_0->qs[offset + 0] | (bq1_0->qs[offset + 1] << 8) | + (bq1_0->qs[offset + 2] << 16) | (bq1_0->qs[offset + 3] << 24); + + // Unpack 32 bits into 32 signed values (-1 or +1) + int vi_bytes[8]; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int shift = j * 4; + const int bits4 = (v >> shift) & 0x0F; + const int b0 = (bits4 & 0x01) ? 1 : -1; + const int b1 = (bits4 & 0x02) ? 1 : -1; + const int b2 = (bits4 & 0x04) ? 1 : -1; + const int b3 = (bits4 & 0x08) ? 1 : -1; + vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24); + } + + // Compute dot product for this 32-element chunk + int sumi = 0; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int u = get_int_b4(bq8_1_chunk->qs, j); + sumi = ggml_cuda_dp4a(vi_bytes[j], u, sumi); + } + + // Apply Q1_0's single scale and this chunk's Q8_1 scale + const float d8 = __low2float(bq8_1_chunk->ds); + return d1 * d8 * sumi; +} + static __device__ __forceinline__ float vec_dot_q4_0_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { From 7fe6b8e171d23fe12847dbf42309d46144ea6407 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Wed, 15 Apr 2026 19:04:51 +0200 Subject: [PATCH 142/249] vulkan: optimize im2col (llama/21713) * vulkan: improve im2col memory write layout * cap workgroups * minimal device tuning * use vendor_id instead of subgroup size --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 ++- .../ggml-vulkan/vulkan-shaders/im2col.comp | 96 +++++++------------ 2 files changed, 46 insertions(+), 63 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b2a54bd85d0..702a249d754 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1394,7 +1394,7 @@ struct vk_op_im2col_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; uint32_t KW; uint32_t KH; - uint32_t pelements; + uint32_t OH_batch; uint32_t CHW; int32_t s0; int32_t s1; int32_t p0; int32_t p1; @@ -10064,7 +10064,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t batch = src1->ne[is_2D ? 3 : 2]; - elements = { OW * KW * KH, OH, batch * IC }; + const uint32_t CHW = IC * KH * KW; + // Cap X workgroups to limit concurrent IC channel reads. + // The shader loops over X to cover the full CHW dimension. + // AMD prefers a lower limit + const uint32_t min_cap = ctx->device->vendor_id == VK_VENDOR_ID_AMD ? 512u : 4096u; + const uint32_t x_elements = std::min(CHW, std::max(min_cap, OW * KH * KW)); + elements = { x_elements, OW, OH * batch }; elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); } break; @@ -11727,7 +11733,6 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 - const uint32_t pelements = OW * KW * KH; const uint32_t batch = src1->ne[is_2D ? 3 : 2]; const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; @@ -11739,7 +11744,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co dst_addr, batch_offset, offset_delta, IC, IW, IH, OW, OH, KW, KH, - pelements, + OH * batch, IC * KH * KW, s0, s1, p0, p1, d0, d1, batch * IC }); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 674f91e5ed2..ba4c2103f0c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -13,7 +13,7 @@ layout (push_constant) uniform parameter uint IW; uint IH; uint OW; uint OH; uint KW; uint KH; - uint pelements; + uint OH_batch; uint CHW; int s0; int s1; int p0; int p1; @@ -34,82 +34,60 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (buffer_reference) buffer D_ptr {D_TYPE d;}; #endif -void im2col(const uint y, const uint z) { - const uint gidx = gl_GlobalInvocationID.x; +void im2col(const uint ow, const uint z_idx) { + const uint oh = z_idx % p.OH; + const uint batch_idx = z_idx / p.OH; - const uint oh = y; - const uint batch = z / p.IC; - const uint ic = z % p.IC; + const uint gidx = gl_LocalInvocationID.x; + const uint src_batch = batch_idx * p.batch_offset; + const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW; - const uint src_base = ic * p.offset_delta + batch * p.batch_offset; - const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); - const int oh_s1 = int(oh) * p.s1; - const uint ksize = p.OW * p.KH; + const uint KHKW = p.KH * p.KW; - const uint base_linear_idx = gidx * NUM_ITER; + uint wg_x = gl_WorkGroupID.x; + do { + const uint wg_offset = wg_x * 512; - uint current_kx = base_linear_idx / ksize; - const uint rem = base_linear_idx - (current_kx * ksize); - uint current_ky = rem / p.OW; - uint current_ix = rem % p.OW; + [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { + const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE; - A_TYPE values[NUM_ITER]; - BDA_OFFSET_T offset_dst[NUM_ITER]; - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - values[idx] = A_TYPE(0); - } - - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - - const uint linear_idx = base_linear_idx + idx; - - if (linear_idx >= p.pelements) { - continue; - } - - const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; - const uint iih = oh_s1 + current_ky * p.d1 - p.p1; - - offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx; - - if ((iih < p.IH) && (iiw < p.IW)) { - values[idx] = data_a[src_base + iih * p.IW + iiw]; - } - - if (++current_ix == p.OW) { - current_ix = 0; - if (++current_ky == p.KH) { - current_ky = 0; - current_kx++; + if (chw_idx >= p.CHW) { + return; } - } - } - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + const uint ic = chw_idx / KHKW; + const uint rem = chw_idx - ic * KHKW; + const uint ky = rem / p.KW; + const uint kx = rem - ky * p.KW; - const uint linear_idx = base_linear_idx + idx; + const uint iiw = ow * p.s0 + kx * p.d0 - p.p0; + const uint iih = oh * p.s1 + ky * p.d1 - p.p1; - if (linear_idx >= p.pelements) { - continue; - } + A_TYPE val = A_TYPE(0); + if (iih < p.IH && iiw < p.IW) { + val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw]; + } #if BDA - D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]); - dst_addr.d = D_TYPE(values[idx]); + D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx)); + out_ptr.d = D_TYPE(val); #else - data_d[offset_dst[idx]] = D_TYPE(values[idx]); + data_d[dst_row + chw_idx] = D_TYPE(val); #endif - } + } + + wg_x += gl_NumWorkGroups.x; + } while (wg_x * 512 < p.CHW); } void main() { - uint y = gl_GlobalInvocationID.y; - while (y < p.OH) { + uint ow = gl_GlobalInvocationID.y; + while (ow < p.OW) { uint z = gl_GlobalInvocationID.z; - while (z < p.batch_IC) { - im2col(y, z); + while (z < p.OH_batch) { + im2col(ow, z); z += gl_NumWorkGroups.z; } - y += gl_NumWorkGroups.y; + ow += gl_NumWorkGroups.y; } } From f62bb133207f47e9975dfb511b119a304f23622d Mon Sep 17 00:00:00 2001 From: Katostrofik Date: Thu, 16 Apr 2026 01:34:05 -0400 Subject: [PATCH 143/249] Fix Q8_0 reorder: garbage on 2nd prompt + crash on full VRAM (llama/21638) * [SYCL] Fix Q8_0 reorder: add missing dequantize path for GEMM The Q8_0 reorder optimization (#21527) was missing a reorder-aware dequantizer for the GEMM code path used during prompt processing. After token generation reordered Q8_0 weights (via DMMV/MMVQ), the next prompt processing pass would read them with the standard dequantizer, producing garbage output. Add dequantize_block_q8_0_reorder() and wire it into both ggml_get_to_fp16_sycl() and ggml_get_to_fp32_sycl(), matching the pattern already used by Q4_0, Q4_K, and Q6_K. Fixes #21589 AI (Claude) was used to assist with root cause investigation and writing the kernel code. All code was human-reviewed and tested on real hardware. * SYCL: fix reorder crash when device memory is full The reorder optimization allocates a temporary buffer the full size of the weight tensor on the device. When VRAM is nearly full (large models on a single GPU), this allocation fails and the subsequent memcpy crashes on a NULL pointer. Fix: try device allocation first, fall back to host memory if device memory is full. The reorder kernel still works correctly reading from host memory over PCIe. This is slower for the one-time reorder (~21 t/s vs ~38 t/s on Intel Arc Pro B70), but the optimization is preserved for all subsequent inference. If both device and host allocation fail, skip the reorder and fall back to the unoptimized kernel path. Also fixes a bug where opt_for_reorder() marked tensors as reordered even when the reorder was skipped due to allocation failure. This caused DMMV/MMVQ kernels to read the original AoS data as if it were SoA, producing garbage output or NaN results. Tested on Intel Arc Pro B70 (32GB) with Q8_0, Q4_K_M models. Coding was AI-assisted (Claude), reviewed and tested on hardware by a human. Fixes #20478 * SYCL: add RAII temp buffer class + macro guard for host fallback Replace sycl_ext_malloc_with_fallback/sycl_ext_free_fallback free functions with sycl_reorder_temp_buffer RAII class. The host_fallback bool is now a private member, and cleanup happens automatically at scope exit. Add GGML_SYCL_HOST_MEM_FALLBACK cmake option (default ON) to guard the host memory fallback code path. Device access to host memory requires Linux kernel 6.8+ (Ubuntu 26.04+); users on older kernels can set -DGGML_SYCL_HOST_MEM_FALLBACK=OFF to disable it. Addresses arthw's review on PR #21638. Co-Authored-By: Claude Opus 4.6 (1M context) * SYCL: document GGML_SYCL_HOST_MEM_FALLBACK build option in SYCL.md Co-Authored-By: Claude Opus 4.6 (1M context) * SYCL: add reorder-aware DMMV dequantizers for Q4_K and Q6_K Q4_K and Q6_K had reorder support for MMVQ and GEMM paths but not DMMV. When the DMMV path encountered reordered data it would abort. Add DMMV kernels that read from the SOA reorder layout for both types. Same math as the non-reorder versions, different memory access pattern. Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-sycl/CMakeLists.txt | 5 + ggml/src/ggml-sycl/convert.cpp | 33 ++- ggml/src/ggml-sycl/dequantize.hpp | 28 +++ ggml/src/ggml-sycl/dmmv.cpp | 321 +++++++++++++++++++++++++++++- ggml/src/ggml-sycl/ggml-sycl.cpp | 106 +++++++--- 6 files changed, 465 insertions(+), 29 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 8454eecde6e..6b65ecd6e5c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -254,6 +254,7 @@ option(GGML_RPC "ggml: use RPC" option(GGML_SYCL "ggml: use SYCL" OFF) option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON) +option(GGML_SYCL_HOST_MEM_FALLBACK "ggml: allow host memory fallback in SYCL reorder (requires kernel 6.8+)" ON) option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 7b07b227874..8e589fa238d 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -154,6 +154,11 @@ if (GGML_SYCL_GRAPH) target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH) endif() +if (GGML_SYCL_HOST_MEM_FALLBACK) + message(STATUS "find GGML_SYCL_HOST_MEM_FALLBACK") + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_HOST_MEM_FALLBACK) +endif() + if (GGML_SYCL_DEVICE_ARCH) target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index f12419426ae..f3c521b45f6 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -151,6 +151,25 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int } +template +static void dequantize_row_q8_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + int constexpr WARP_K = WARP_SIZE * QK8_0; + const int n_warp = (k + WARP_K - 1) / WARP_K; + GGML_ASSERT(k % QK8_0 == 0); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * + sycl::range<3>(1, 1, WARP_SIZE), + sycl::range<3>(1, 1, WARP_SIZE)), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{ + dequantize_block_q8_0_reorder(vx, y, k, item_ct1); + }); + +} + template static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -614,7 +633,12 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { case GGML_TYPE_Q5_1: return dequantize_block_sycl; case GGML_TYPE_Q8_0: - return dequantize_block_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q8_0_sycl_reorder; + } else { + return dequantize_block_sycl; + } case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: @@ -683,7 +707,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { case GGML_TYPE_Q5_1: return dequantize_block_sycl; case GGML_TYPE_Q8_0: - return dequantize_block_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q8_0_sycl_reorder; + } else { + return dequantize_block_sycl; + } case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 68c3db30613..19fa88680d6 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -239,6 +239,34 @@ static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * } +// Dequantize Q8_0 from reorder layout: [all qs (k bytes)][all d values] +// Each thread handles one block of QK8_0 elements. +template +static void dequantize_block_q8_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t k, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + const int64_t tid = item_ct1.get_local_id(2); + const int lane_ib = i * WARP_SIZE + tid; + + if (lane_ib >= k / QK8_0) { + return; + } + + dst_t * y_ptr = yy + lane_ib * QK8_0; + + auto qs = (const int8_t*)vx + lane_ib * QK8_0; + auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k) + lane_ib; + + const float d = float(*s_ptr); + +#pragma unroll + for (int l = 0; l < QK8_0; ++l) { + y_ptr[l] = d * qs[l]; + } + +} + template static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 1c8b6f3771f..5577bf73b28 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -615,6 +615,162 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, } } +static void dequantize_mul_mat_vec_q4_k_reorder(const void *__restrict__ vx, + const float *__restrict__ yy, + float *__restrict__ dst, + const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + // SOA base pointers for the reordered layout: + // [qs: nb * QK_K/2] [scales: nb * K_SCALE_SIZE] [dm: nb * sizeof(half2)] + const int nb = nrows * num_blocks_per_row; + const uint8_t * qs_base = (const uint8_t *)vx; + const uint8_t * scales_base = qs_base + (size_t)nb * (QK_K / 2); + const sycl::half2 * dm_base = (const sycl::half2 *)(scales_base + (size_t)nb * K_SCALE_SIZE); + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + + const int il = tid/step; // 0...3 + const int ir = tid - step*il; // 0...7 or 0...3 + const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + +#if K_QUANTS_PER_ITERATION == 2 + uint32_t q32[4]; + const uint8_t * q4 = (const uint8_t *)q32; +#else + uint16_t q16[4]; + const uint8_t * q4 = (const uint8_t *)q16; +#endif + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const sycl::half2 dm_val = dm_base[bi]; + const float dall = dm_val[0]; + const float dmin = dm_val[1]; + + const uint16_t * a = (const uint16_t *)(scales_base + bi * K_SCALE_SIZE); + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + +#if K_QUANTS_PER_ITERATION == 2 + const uint32_t * q1 = (const uint32_t *)(qs_base + bi * (QK_K / 2) + q_offset); + const uint32_t * q2 = q1 + 16; + + q32[0] = q1[0] & 0x0f0f0f0f; + q32[1] = q1[0] & 0xf0f0f0f0; + q32[2] = q2[0] & 0x0f0f0f0f; + q32[3] = q2[0] & 0xf0f0f0f0; + + sycl::float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 4; ++l) { + s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4]; + s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f + + s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) - + dmin * smin; +#else + const uint16_t * q1 = (const uint16_t *)(qs_base + bi * (QK_K / 2) + q_offset); + const uint16_t * q2 = q1 + 32; + + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[0] & 0xf0f0; + q16[2] = q2[0] & 0x0f0f; + q16[3] = q2[0] & 0xf0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 2; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; + s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#endif + + } +#else + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); + + const int step = tid * K_QUANTS_PER_ITERATION; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + float tmp = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const uint8_t * q = qs_base + bi * (QK_K / 2) + step; + const float * y = yy + i*QK_K + step; + const uint16_t * a = (const uint16_t *)(scales_base + bi * K_SCALE_SIZE); + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + const sycl::half2 dm_val = dm_base[bi]; + const float d = (float)dm_val[0]; + const float m = (float)dm_val[1]; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) + + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) + + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) + + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); + } + tmp += sum; + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + /* DPCT1110:7: The total declared local variable size in device function dequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register @@ -864,6 +1020,129 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa } } +static void dequantize_mul_mat_vec_q6_k_reorder(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + // SOA base pointers for the reordered layout: + // [ql: nb * QK_K/2] [qh: nb * QK_K/4] [scales: nb * QK_K/16] [d: nb * sizeof(half)] + const int nb = nrows * num_blocks_per_row; + const uint8_t * ql_base = (const uint8_t *)vx; + const uint8_t * qh_base = ql_base + (size_t)nb * (QK_K / 2); + const int8_t * scales_base = (const int8_t *)(qh_base + (size_t)nb * (QK_K / 4)); + const sycl::half * d_base = (const sycl::half *)((const uint8_t *)scales_base + (size_t)nb * (QK_K / 16)); + +#if QK_K == 256 + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1 + + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * ql = ql_base + bi * (QK_K / 2) + ql_offset; + const uint8_t * qh = qh_base + bi * (QK_K / 4) + qh_offset; + const int8_t * s = scales_base + bi * (QK_K / 16) + s_offset; + + const float d = d_base[bi]; + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp += sum; +#else + float sum = 0; + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp += sum; +#endif + + } + +#else + + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3 + + const int step = tid * K_QUANTS_PER_ITERATION; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y = yy + i * QK_K + step; + const uint8_t * ql = ql_base + bi * (QK_K / 2) + step; + const uint8_t * qh = qh_base + bi * (QK_K / 4) + step; + const int8_t * s = scales_base + bi * (QK_K / 16); + + const float d = d_base[bi]; + + float sum = 0; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); + } + tmp += sum; + + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, @@ -1167,6 +1446,38 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y, }); } +static void dequantize_mul_mat_vec_q4_K_sycl_reorder(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q4_k_reorder(vx, y, dst, ncols, nrows, item_ct1); + }); +} + +static void dequantize_mul_mat_vec_q6_K_sycl_reorder(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q6_k_reorder(vx, y, dst, ncols, nrows, item_ct1); + }); +} + void ggml_sycl_op_dequantize_mul_mat_vec( ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, @@ -1235,8 +1546,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec( case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - // reorder is currently not supported for dmmv - GGML_ABORT("Unimplemented dequantize case case for q4_k reorder"); + dequantize_mul_mat_vec_q4_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); } else { dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); } @@ -1245,7 +1555,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q6_K: - dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q6_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_F16: convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index ea79d2538c1..c02a41ad862 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3348,9 +3348,55 @@ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { sycl::free(ptr, *stream); } -static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, +// RAII wrapper for temporary reorder buffers with optional host memory fallback. +// When device allocation fails and GGML_SYCL_HOST_MEM_FALLBACK is enabled, +// falls back to host memory so the reorder kernel can still run (over PCIe). +// Device access to host memory requires Linux kernel 6.8+ (Ubuntu 26.04+). +struct sycl_reorder_temp_buffer { + void * ptr = nullptr; + dpct::queue_ptr stream; + + sycl_reorder_temp_buffer(dpct::queue_ptr stream, size_t size) : stream(stream) { + ptr = sycl_ext_malloc_device(stream, size); +#ifdef GGML_SYCL_HOST_MEM_FALLBACK + if (!ptr) { + ptr = sycl::malloc_host(size, *stream); + if (ptr) { + host_fallback = true; + GGML_LOG_WARN("%s: device alloc of %zu bytes failed, using host memory fallback\n", __func__, size); + } + } +#endif + } + + ~sycl_reorder_temp_buffer() { + if (!ptr) { + return; + } + if (host_fallback) { + sycl::free(ptr, *stream); + } else { + sycl_ext_free(stream, ptr); + } + } + + explicit operator bool() const { return ptr != nullptr; } + + sycl_reorder_temp_buffer(const sycl_reorder_temp_buffer &) = delete; + sycl_reorder_temp_buffer & operator=(const sycl_reorder_temp_buffer &) = delete; + +private: + bool host_fallback = false; +}; + +static bool reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, dpct::queue_ptr stream) { - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3379,12 +3425,17 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, +static bool reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, dpct::queue_ptr stream) { - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3413,16 +3464,21 @@ static void reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nr if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { +static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q4_K) == 0); GGML_ASSERT(offset % sizeof(block_q4_K) == 0); const int nblocks = size / sizeof(block_q4_K); - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3451,16 +3507,21 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { +static bool reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q6_K) == 0); GGML_ASSERT(offset % sizeof(block_q6_K) == 0); const int nblocks = size / sizeof(block_q6_K); - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3499,10 +3560,10 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { +static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { uint8_t * data_device = (uint8_t *) src0->data; size_t ncols = src0->ne[0]; size_t nrows = src0->ne[1]; @@ -3510,20 +3571,16 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { switch (src0->type) { case GGML_TYPE_Q4_0: - reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); - break; + return reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); case GGML_TYPE_Q8_0: - reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); - break; + return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); case GGML_TYPE_Q4_K: - reorder_qw_q4_k(data_device, size, 0, stream); - break; + return reorder_qw_q4_k(data_device, size, 0, stream); case GGML_TYPE_Q6_K: - reorder_qw_q6_k(data_device, size, 0, stream); - break; + return reorder_qw_q6_k(data_device, size, 0, stream); default: GGML_ABORT("reorder_qw() called with unsupported type"); - break; + return false; } } @@ -3563,8 +3620,9 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * break; } - reorder_qw(src0, ctx->stream()); - extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering + if (reorder_qw(src0, ctx->stream())) { + extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering + } } From 092330b474ed34f80ed854ae7b64034a94a6f79a Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Thu, 16 Apr 2026 01:12:19 -0700 Subject: [PATCH 144/249] ggml-webgpu: compute pass batching and removing profiling overhead (llama/21873) * Update register tiling matmul to use f32 accumulation * fix profiling code * Fix register tiling matmul for chrome, i'm blaming dawn * Update batch tuning value for iOS * compile fix * Fix use of new load function * Move to a single query set for GPU profiling * Move to batching compute passes when not profiling * Refactor build_multi * remove iOS throttling now that we're batching compute passes --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 799 ++++++++++++--------------- 1 file changed, 348 insertions(+), 451 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index aa3fe06d5a9..01637e2ddab 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -73,8 +73,8 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #endif // GGML_WEBGPU_CPU_PROFILE #ifdef GGML_WEBGPU_GPU_PROFILE -# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32 -# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps +# define WEBGPU_MAX_PROFILE_QUERY_COUNT 4096u +# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES (WEBGPU_MAX_PROFILE_QUERY_COUNT * sizeof(uint64_t)) #endif /* Constants */ @@ -159,78 +159,20 @@ struct webgpu_param_arena { ~webgpu_param_arena() { this->cleanup(); } }; -#ifdef GGML_WEBGPU_GPU_PROFILE -struct webgpu_gpu_profile_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - wgpu::QuerySet query_set; -}; - -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_gpu_profile_buf_pool { - std::vector free; - - std::mutex mutex; - - std::condition_variable cv; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { - for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf"); - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf"); - // Create a query set for 2 timestamps - wgpu::QuerySetDescriptor ts_query_set_desc = {}; - - ts_query_set_desc.type = wgpu::QueryType::Timestamp; - ts_query_set_desc.count = 2; - wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc); - - free.push_back({ host_buf, dev_buf, ts_query_set }); - } - } - - webgpu_gpu_profile_bufs alloc_bufs() { - std::unique_lock lock(mutex); - cv.wait(lock, [this] { return !free.empty(); }); - webgpu_gpu_profile_bufs bufs = free.back(); - free.pop_back(); - return bufs; - } - - void free_bufs(std::vector bufs) { - std::lock_guard lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } - - void cleanup() { - std::lock_guard lock(mutex); - for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); - bufs.query_set.Destroy(); - } - free.clear(); - } - - ~webgpu_gpu_profile_buf_pool() { this->cleanup(); } -}; -#endif - struct webgpu_encoded_op { uint32_t num_kernels = 0; #ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs timestamp_query_bufs; - std::string pipeline_name; + std::vector pipeline_names; #endif }; +struct webgpu_dispatch_desc { + webgpu_pipeline pipeline; + std::vector params; + std::vector bind_group_entries; + std::pair workgroups = { 1, 1 }; +}; + struct webgpu_capabilities { wgpu::Limits limits; bool supports_subgroup_matrix = false; @@ -256,7 +198,7 @@ struct webgpu_global_context_struct { webgpu_capabilities capabilities; // Shared buffer to move data from device to host wgpu::Buffer get_tensor_staging_buf; - // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches. + // Global mutex for get_tensor std::recursive_mutex mutex; wgpu::Buffer memset_params_buf; @@ -272,8 +214,6 @@ struct webgpu_global_context_struct { #ifdef GGML_WEBGPU_GPU_PROFILE // Profiling: per-shader GPU time in ms std::unordered_map shader_gpu_time_ms; - // Profiling: pool of timestamp query buffers (one per operation) - webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; #endif #ifdef GGML_WEBGPU_DEBUG @@ -312,11 +252,45 @@ struct webgpu_context_struct { std::unique_ptr shader_lib; - webgpu_param_arena param_arena; - wgpu::Buffer set_rows_dev_error_buf; - wgpu::Buffer set_rows_host_error_buf; + webgpu_param_arena param_arena; + wgpu::Buffer set_rows_dev_error_buf; + wgpu::Buffer set_rows_host_error_buf; + wgpu::CommandEncoder active_command_encoder; + wgpu::ComputePassEncoder active_compute_pass; size_t memset_bytes_per_thread; + +#ifdef GGML_WEBGPU_GPU_PROFILE + wgpu::Buffer profile_timestamp_dev_buf; + wgpu::Buffer profile_timestamp_host_buf; + wgpu::QuerySet profile_timestamp_query_set; + uint32_t profile_timestamp_query_count = 0; +#endif + + ~webgpu_context_struct() { +#ifdef GGML_WEBGPU_GPU_PROFILE + if (this->profile_timestamp_host_buf) { + this->profile_timestamp_host_buf.Destroy(); + this->profile_timestamp_host_buf = nullptr; + } + if (this->profile_timestamp_dev_buf) { + this->profile_timestamp_dev_buf.Destroy(); + this->profile_timestamp_dev_buf = nullptr; + } + if (this->profile_timestamp_query_set) { + this->profile_timestamp_query_set.Destroy(); + this->profile_timestamp_query_set = nullptr; + } +#endif + if (this->set_rows_host_error_buf) { + this->set_rows_host_error_buf.Destroy(); + this->set_rows_host_error_buf = nullptr; + } + if (this->set_rows_dev_error_buf) { + this->set_rows_dev_error_buf.Destroy(); + this->set_rows_dev_error_buf = nullptr; + } + } }; typedef std::shared_ptr webgpu_context; @@ -399,24 +373,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ -#ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx, - std::vector & futures) { - if (futures.empty()) { - return; - } - - constexpr size_t max_futures_per_wait = 64; - - while (!futures.empty()) { - ctx->instance.WaitAny(std::min(max_futures_per_wait, futures.size()), futures.data(), UINT64_MAX); - futures.erase(std::remove_if(futures.begin(), futures.end(), - [](const wgpu::FutureWaitInfo & info) { return info.completed; }), - futures.end()); - } -} -#endif - template static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status, T callback_status, @@ -436,22 +392,8 @@ static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status, } } -#ifdef __EMSCRIPTEN__ -EM_JS(int, ggml_webgpu_is_ios_browser, (), { - const ua = navigator.userAgent; - return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0; -}); -#endif - // TODO: these next two functions may want tuning across different platforms and workloads, static uint32_t ggml_backend_webgpu_get_max_inflight_batches() { -#ifdef __EMSCRIPTEN__ - // iOS has very strict limits on the number of in-flight GPU commands, - // so we need to throttle to avoid failures. - if (ggml_webgpu_is_ios_browser()) { - return 1; - } -#endif return UINT32_MAX; } @@ -524,118 +466,77 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -#ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx, - const std::vector & commands, - std::vector & futures) { - for (const auto & command : commands) { - auto label = command.pipeline_name; - auto ts_bufs = command.timestamp_query_bufs; - - wgpu::Future f = ts_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); - } else { - const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange(); - // WebGPU timestamps are in ns; convert to ms - double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6; - ctx->shader_gpu_time_ms[label] += elapsed_ms; - } - // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); - }); - futures.push_back({ f }); - } -} -#endif - -static webgpu_encoded_op ggml_backend_webgpu_build_multi( - webgpu_global_context & ctx, - webgpu_param_arena & param_arena, - wgpu::CommandEncoder & encoder, - const std::vector & pipelines, - const std::vector> & params_list, - const std::vector> & bind_group_entries_list, - const std::vector> & workgroups_list) { - GGML_ASSERT(pipelines.size() == params_list.size()); - GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); - GGML_ASSERT(pipelines.size() == workgroups_list.size()); - +static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & ctx, + const std::vector & dispatches) { webgpu_encoded_op result = {}; std::vector bind_groups; std::vector param_offsets; - result.num_kernels = pipelines.size(); + result.num_kernels = dispatches.size(); - for (size_t i = 0; i < pipelines.size(); i++) { - const size_t param_size = params_list[i].size() * sizeof(uint32_t); - const size_t param_offset = param_arena.alloc_slot(param_size); + for (size_t i = 0; i < dispatches.size(); i++) { + const webgpu_dispatch_desc & dispatch = dispatches[i]; + const size_t param_size = dispatch.params.size() * sizeof(uint32_t); + const size_t param_offset = ctx->param_arena.alloc_slot(param_size); - std::vector entries = bind_group_entries_list[i]; + std::vector entries = dispatch.bind_group_entries; uint32_t params_binding_num = entries.size(); entries.push_back({ .binding = params_binding_num, - .buffer = param_arena.buffer, + .buffer = ctx->param_arena.buffer, .offset = param_offset, - .size = param_arena.slot_size }); + .size = ctx->param_arena.slot_size }); wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); + bind_group_desc.layout = dispatch.pipeline.pipeline.GetBindGroupLayout(0); bind_group_desc.entryCount = entries.size(); bind_group_desc.entries = entries.data(); - bind_group_desc.label = pipelines[i].name.c_str(); - bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc)); + bind_group_desc.label = dispatch.pipeline.name.c_str(); + bind_groups.push_back(ctx->global_ctx->device.CreateBindGroup(&bind_group_desc)); param_offsets.push_back(param_offset); } for (size_t i = 0; i < param_offsets.size(); i++) { - ctx->queue.WriteBuffer(param_arena.buffer, param_offsets[i], params_list[i].data(), - params_list[i].size() * sizeof(uint32_t)); + ctx->global_ctx->queue.WriteBuffer(ctx->param_arena.buffer, param_offsets[i], dispatches[i].params.data(), + dispatches[i].params.size() * sizeof(uint32_t)); } + #ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); - if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - ts_bufs.host_buf.Unmap(); + for (size_t i = 0; i < dispatches.size(); i++) { + GGML_ASSERT(ctx->profile_timestamp_query_count + 2 <= WEBGPU_MAX_PROFILE_QUERY_COUNT); + const uint32_t query_begin = ctx->profile_timestamp_query_count++; + const uint32_t query_end = ctx->profile_timestamp_query_count++; + wgpu::PassTimestampWrites ts_writes = { .querySet = ctx->profile_timestamp_query_set, + .beginningOfPassWriteIndex = query_begin, + .endOfPassWriteIndex = query_end }; + wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); + + pass.SetPipeline(dispatches[i].pipeline.pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); + pass.End(); + result.pipeline_names.push_back(dispatches[i].pipeline.name); } - - wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, - .beginningOfPassWriteIndex = 0, - .endOfPassWriteIndex = 1 }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); #else - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); -#endif - for (size_t i = 0; i < pipelines.size(); i++) { - pass.SetPipeline(pipelines[i].pipeline); - pass.SetBindGroup(0, bind_groups[i]); - pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); + for (size_t i = 0; i < dispatches.size(); i++) { + ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline); + ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]); + ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); } - pass.End(); - -#ifdef GGML_WEBGPU_GPU_PROFILE - encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); - encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); - result.timestamp_query_bufs = ts_bufs; - result.pipeline_name = pipelines.front().name; #endif + return result; } -static webgpu_encoded_op ggml_backend_webgpu_build(webgpu_global_context & ctx, - webgpu_param_arena & param_arena, - wgpu::CommandEncoder & encoder, +static webgpu_encoded_op ggml_backend_webgpu_build(webgpu_context & ctx, webgpu_pipeline & pipeline, std::vector params, std::vector bind_group_entries, uint32_t wg_x, uint32_t wg_y = 1) { - return ggml_backend_webgpu_build_multi(ctx, param_arena, encoder, - { - pipeline - }, - { std::move(params) }, { std::move(bind_group_entries) }, - { { wg_x, wg_y } }); + return ggml_backend_webgpu_build_multi( + ctx, { + { pipeline, std::move(params), std::move(bind_group_entries), { wg_x, wg_y } }, + }); } static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, @@ -784,10 +685,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 return flags; } -static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, @@ -825,14 +723,13 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { const bool inplace = ggml_webgpu_tensor_equal(src0, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -891,13 +788,10 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; @@ -949,14 +843,13 @@ static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1011,14 +904,13 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1068,18 +960,17 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2]; - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * src3, - ggml_tensor * src4, - ggml_tensor * src5, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1154,14 +1045,13 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, h, n_seqs); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, h, n_seqs); } -static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { // For set rows specifically, we need to check if src and idx are empty // tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { @@ -1224,7 +1114,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, 1); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1); } // Workgroup size is a common constant @@ -1235,11 +1125,10 @@ static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_si return constants; } -static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -1291,14 +1180,13 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows; uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { // Determine if this is a mat-vec operation bool is_vec = (dst->ne[1] == 1); @@ -1437,15 +1325,14 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1457,10 +1344,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // Get or create pipeline webgpu_pipeline gather_pipeline, main_pipeline; - std::vector pipelines; - std::vector> params_list; - std::vector> entries_list; - std::vector> workgroups_list; + std::vector dispatches; gather_pipeline = ctx->shader_lib->get_mul_mat_id_gather_pipeline(shader_lib_ctx); main_pipeline = ctx->shader_lib->get_mul_mat_id_pipeline(shader_lib_ctx); @@ -1520,10 +1404,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim); const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x); - pipelines.push_back(gather_pipeline); - params_list.push_back(std::move(gather_params)); - entries_list.push_back(std::move(gather_entries)); - workgroups_list.push_back({ gather_wg_x, gather_wg_y }); + dispatches.push_back({ + gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, gather_wg_y } + }); // params for mul_mat_id.wgsl std::vector main_params = { @@ -1588,24 +1471,21 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); - pipelines.push_back(main_pipeline); - params_list.push_back(std::move(main_params)); - entries_list.push_back(std::move(main_entries)); - workgroups_list.push_back({ wg_x, wg_y }); + dispatches.push_back({ + main_pipeline, std::move(main_params), std::move(main_entries), { wg_x, wg_y } + }); - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } #ifndef __EMSCRIPTEN__ -static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { float scale = *(float *) dst->op_params; float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); @@ -1897,40 +1777,33 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, const uint64_t split_wg_total = (uint64_t) wg_x * nwg; GGML_ASSERT(split_wg_total <= UINT32_MAX); - std::vector pipelines; - std::vector> params_list; - std::vector> entries_list; - std::vector> workgroups_list; + std::vector dispatches; if (use_blk) { - pipelines.push_back(blk_pipeline); - params_list.push_back(std::move(blk_params)); - entries_list.push_back(std::move(blk_entries)); - workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count }); + dispatches.push_back({ + blk_pipeline, + std::move(blk_params), + std::move(blk_entries), + { blk_nblk0, blk_nblk1 * blk_batch_count } + }); } - pipelines.push_back(pipeline); - params_list.push_back(std::move(split_params)); - entries_list.push_back(std::move(split_entries)); - workgroups_list.push_back({ (uint32_t) split_wg_total, 1u }); + dispatches.push_back({ + pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + }); if (use_vec_reduce) { - pipelines.push_back(reduce_pipeline); - params_list.push_back(std::move(reduce_params)); - entries_list.push_back(std::move(reduce_entries)); - workgroups_list.push_back({ (uint32_t) nrows, 1u }); + dispatches.push_back({ + reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } + }); } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } #endif // __EMSCRIPTEN__ -static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); @@ -2005,14 +1878,13 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2108,14 +1980,13 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t dim = (uint32_t) dst->op_params[0]; @@ -2165,13 +2036,10 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { ne, @@ -2210,13 +2078,10 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { @@ -2256,16 +2121,14 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, }; webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, - ggml_nrows(src)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src)); } -static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2362,14 +2225,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2428,13 +2290,10 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2482,15 +2341,14 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2566,14 +2424,10 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, - ggml_nrows(dst)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); } -static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; @@ -2595,13 +2449,10 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nelements(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_top_k = dst->op == GGML_OP_TOP_K; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2659,10 +2510,7 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1]; const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2]; - std::vector pipelines; - std::vector> params_list; - std::vector> entries_list; - std::vector> workgroups_list; + std::vector dispatches; const uint32_t init_offset = start_in_tmp ? offset_tmp : offset_dst; const size_t init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst); @@ -2686,14 +2534,12 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size } }; - pipelines.push_back(argsort_pipeline); - params_list.push_back(std::move(init_params)); - entries_list.push_back(std::move(init_entries)); - workgroups_list.push_back({ wg_x_init, wg_y_init }); + dispatches.push_back({ + argsort_pipeline, std::move(init_params), std::move(init_entries), { wg_x_init, wg_y_init } + }); if (merge_passes == 0) { - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } bool in_is_tmp = start_in_tmp; @@ -2745,23 +2591,18 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, const uint32_t total_wg_merge = nm * nrows; const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg); const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge); - workgroups_list.push_back({ wg_x_merge, wg_y_merge }); - pipelines.push_back(argsort_merge_pipeline); - params_list.push_back(std::move(merge_params)); - entries_list.push_back(std::move(merge_entries)); + dispatches.push_back({ + argsort_merge_pipeline, std::move(merge_params), std::move(merge_entries), { wg_x_merge, wg_y_merge } + }); len <<= 1; in_is_tmp = !in_is_tmp; } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } -static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; @@ -2786,13 +2627,10 @@ static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool total_sum = dst->op == GGML_OP_SUM; std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -2821,13 +2659,11 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } // Returns the encoded command, or std::nullopt if the operation is a no-op -static std::optional ggml_webgpu_encode_node(webgpu_context ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * node) { +static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { return std::nullopt; } @@ -2850,20 +2686,20 @@ static std::optional ggml_webgpu_encode_node(webgpu_context return std::nullopt; case GGML_OP_CPY: case GGML_OP_CONT: - return ggml_webgpu_cpy(ctx, encoder, src0, node); + return ggml_webgpu_cpy(ctx, src0, node); case GGML_OP_SET: - return ggml_webgpu_set(ctx, encoder, src0, src1, node); + return ggml_webgpu_set(ctx, src0, src1, node); case GGML_OP_SET_ROWS: - return ggml_webgpu_set_rows(ctx, encoder, src0, src1, node); + return ggml_webgpu_set_rows(ctx, src0, src1, node); case GGML_OP_GET_ROWS: - return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node); + return ggml_webgpu_get_rows(ctx, src0, src1, node); case GGML_OP_MUL_MAT: - return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node); + return ggml_webgpu_mul_mat(ctx, src0, src1, node); case GGML_OP_MUL_MAT_ID: - return ggml_webgpu_mul_mat_id(ctx, encoder, src0, src1, src2, node); + return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node); case GGML_OP_FLASH_ATTN_EXT: #ifndef __EMSCRIPTEN__ - return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node); + return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); #else return std::nullopt; #endif @@ -2871,22 +2707,22 @@ static std::optional ggml_webgpu_encode_node(webgpu_context case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - return ggml_webgpu_binary_op(ctx, encoder, src0, src1, node); + return ggml_webgpu_binary_op(ctx, src0, src1, node); case GGML_OP_CONCAT: - return ggml_webgpu_concat(ctx, encoder, src0, src1, node); + return ggml_webgpu_concat(ctx, src0, src1, node); case GGML_OP_REPEAT: - return ggml_webgpu_repeat(ctx, encoder, src0, node); + return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: - return ggml_webgpu_row_norm(ctx, encoder, src0, node); + return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: - return ggml_webgpu_rope(ctx, encoder, src0, src1, src2, node); + return ggml_webgpu_rope(ctx, src0, src1, src2, node); case GGML_OP_GLU: - return ggml_webgpu_glu(ctx, encoder, src0, src1, node); + return ggml_webgpu_glu(ctx, src0, src1, node); case GGML_OP_SCALE: - return ggml_webgpu_scale(ctx, encoder, src0, node); + return ggml_webgpu_scale(ctx, src0, node); case GGML_OP_SOFT_MAX: - return ggml_webgpu_soft_max(ctx, encoder, src0, src1, src2, node); + return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); case GGML_OP_UNARY: case GGML_OP_CLAMP: case GGML_OP_FILL: @@ -2897,32 +2733,80 @@ static std::optional ggml_webgpu_encode_node(webgpu_context case GGML_OP_COS: case GGML_OP_DIAG: case GGML_OP_TRI: - return ggml_webgpu_unary_op(ctx, encoder, src0, node); + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SOLVE_TRI: - return ggml_webgpu_solve_tri(ctx, encoder, src0, src1, node); + return ggml_webgpu_solve_tri(ctx, src0, src1, node); case GGML_OP_SSM_CONV: - return ggml_webgpu_ssm_conv(ctx, encoder, src0, src1, node); + return ggml_webgpu_ssm_conv(ctx, src0, src1, node); case GGML_OP_GATED_DELTA_NET: - return ggml_webgpu_gated_delta_net(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node->src[5], - node); + return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); case GGML_OP_PAD: - return ggml_webgpu_pad(ctx, encoder, src0, node); + return ggml_webgpu_pad(ctx, src0, node); case GGML_OP_ARGMAX: - return ggml_webgpu_argmax(ctx, encoder, src0, node); + return ggml_webgpu_argmax(ctx, src0, node); case GGML_OP_ARGSORT: case GGML_OP_TOP_K: // we reuse the same argsort implementation for top_k - return ggml_webgpu_argsort(ctx, encoder, src0, node); + return ggml_webgpu_argsort(ctx, src0, node); case GGML_OP_CUMSUM: - return ggml_webgpu_cumsum(ctx, encoder, src0, node); + return ggml_webgpu_cumsum(ctx, src0, node); case GGML_OP_SUM: case GGML_OP_SUM_ROWS: - return ggml_webgpu_sum_rows(ctx, encoder, src0, node); + return ggml_webgpu_sum_rows(ctx, src0, node); default: return std::nullopt; } } +#ifdef GGML_WEBGPU_GPU_PROFILE +static void ggml_backend_webgpu_collect_profile_results(webgpu_context & ctx, + const std::vector & pipeline_names, + uint32_t & num_inflight_batches) { + if (pipeline_names.empty()) { + return; + } + + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.ResolveQuerySet(ctx->profile_timestamp_query_set, 0, ctx->profile_timestamp_query_count, + ctx->profile_timestamp_dev_buf, 0); + encoder.CopyBufferToBuffer(ctx->profile_timestamp_dev_buf, 0, ctx->profile_timestamp_host_buf, 0, + ctx->profile_timestamp_query_count * sizeof(uint64_t)); + + wgpu::CommandBuffer profile_commands = encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, profile_commands, num_inflight_batches); + + const size_t mapped_size = ctx->profile_timestamp_query_count * sizeof(uint64_t); + GGML_ASSERT(ctx->profile_timestamp_query_count == 2 * pipeline_names.size()); + + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->profile_timestamp_host_buf, wgpu::MapMode::Read, 0, + mapped_size); + const uint64_t * ts_data = (const uint64_t *) ctx->profile_timestamp_host_buf.GetConstMappedRange(0, mapped_size); + + for (size_t i = 0; i < pipeline_names.size(); ++i) { + // WebGPU timestamps are in ns; convert to ms. + const double elapsed_ms = double(ts_data[2 * i + 1] - ts_data[2 * i]) * 1e-6; + ctx->global_ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; + } + + ctx->profile_timestamp_host_buf.Unmap(); +} +#endif + +static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & num_inflight_batches) { + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, + ctx->set_rows_host_error_buf.GetSize()); + wgpu::CommandBuffer commands = encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, commands, num_inflight_batches); + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, + ctx->set_rows_host_error_buf.GetSize()); + const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + ctx->set_rows_host_error_buf.Unmap(); +} + static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); @@ -2932,69 +2816,77 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); std::vector commands; + + uint32_t num_batched_kernels = 0; + uint32_t num_inflight_batches = 0; + bool contains_set_rows = false; + bool batch_compute_passes = true; + #ifdef GGML_WEBGPU_GPU_PROFILE - std::vector profile_futures; + ctx->profile_timestamp_query_count = 0; + batch_compute_passes = false; + std::vector profile_pipeline_names; #endif - uint32_t num_batched_kernels = 0; - uint32_t num_inflight_batches = 0; - bool contains_set_rows = false; - wgpu::CommandEncoder batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + + ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + if (batch_compute_passes) { + ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); + } for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } - if (auto cmd = ggml_webgpu_encode_node(ctx, batch_encoder, cgraph->nodes[i])) { + if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; +#ifdef GGML_WEBGPU_GPU_PROFILE + profile_pipeline_names.insert(profile_pipeline_names.end(), cmd->pipeline_names.begin(), + cmd->pipeline_names.end()); +#endif } if (num_batched_kernels >= ctx->global_ctx->command_submit_batch_size) { + if (ctx->active_compute_pass) { + ctx->active_compute_pass.End(); + } num_batched_kernels = 0; - wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); + wgpu::CommandBuffer batch_commands = ctx->active_command_encoder.Finish(); ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); -#endif + + // reset state for next batch + ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + if (batch_compute_passes) { + ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); + } ctx->param_arena.reset(); commands.clear(); - batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); } } - if (!commands.empty()) { - wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); + + if (ctx->active_compute_pass) { + ctx->active_compute_pass.End(); + ctx->active_compute_pass = nullptr; + } + + if (num_batched_kernels > 0) { + wgpu::CommandBuffer batch_commands = ctx->active_command_encoder.Finish(); ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); -#endif ctx->param_arena.reset(); commands.clear(); } + ctx->active_command_encoder = nullptr; + +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_collect_profile_results(ctx, profile_pipeline_names, num_inflight_batches); +#endif - // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking. if (contains_set_rows) { - wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, - ctx->set_rows_host_error_buf.GetSize()); - wgpu::CommandBuffer set_rows_commands = encoder.Finish(); - ggml_backend_webgpu_submit_commands(ctx, set_rows_commands, num_inflight_batches); + ggml_backend_webgpu_check_set_rows(ctx, num_inflight_batches); } ggml_backend_webgpu_wait_queue(ctx->global_ctx); - if (contains_set_rows) { - ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, - ctx->set_rows_host_error_buf.GetSize()); - const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); - if (*error_data) { - GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); - } - ctx->set_rows_host_error_buf.Unmap(); - } - -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx->global_ctx, profile_futures); -#endif WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -3535,14 +3427,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { "memset_params_buf"); ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue(); -#ifdef GGML_WEBGPU_GPU_PROFILE - // Initialize buffer pool for timestamp queries, used for profiling - ctx->webgpu_global_ctx->timestamp_query_buf_pool.init( - ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, - wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, - wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); -#endif - GGML_LOG_INFO( "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " "device_desc: %s\n", @@ -3567,6 +3451,19 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_webgpu_create_buffer( + webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_dev_buf, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, "profile_timestamp_dev_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_host_buf, + WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "profile_timestamp_host_buf"); + wgpu::QuerySetDescriptor query_set_desc = {}; + query_set_desc.type = wgpu::QueryType::Timestamp; + query_set_desc.count = WEBGPU_MAX_PROFILE_QUERY_COUNT; + webgpu_ctx->profile_timestamp_query_set = webgpu_ctx->global_ctx->device.CreateQuerySet(&query_set_desc); +#endif + #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf, From 07c181b57f6d59027ea6fe3931967993e4f870a6 Mon Sep 17 00:00:00 2001 From: rehan-10xengineer Date: Thu, 16 Apr 2026 13:14:26 +0500 Subject: [PATCH 145/249] ggml : implemented simd_gemm kernel for riscv vector extension (llama/20627) Co-authored-by: Rehan Qasim --- ggml/src/ggml-cpu/simd-gemm.h | 90 +++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/ggml/src/ggml-cpu/simd-gemm.h b/ggml/src/ggml-cpu/simd-gemm.h index 78d663e593e..4119d04f895 100644 --- a/ggml/src/ggml-cpu/simd-gemm.h +++ b/ggml/src/ggml-cpu/simd-gemm.h @@ -109,6 +109,96 @@ static void simd_gemm( C += N; } } +#elif defined(GGML_SIMD) && defined(__riscv_v_intrinsic) +// RM accumulators + 1 B vector = RM + 1 <= 8 => RM <= 7 +// Microkernel: C[RM x vl] += A[RM x K] * B[K x N] +template +static inline void rvv_simd_gemm_ukernel( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int K, int N, size_t vl) +{ + static_assert(RM >= 1 && RM <= 7, "RM must be 1..7 for LMUL=4"); + + vfloat32m4_t acc_0 = __riscv_vle32_v_f32m4(C + 0 * N, vl); + vfloat32m4_t acc_1, acc_2, acc_3, acc_4, acc_5, acc_6; + if constexpr (RM > 1) acc_1 = __riscv_vle32_v_f32m4(C + 1 * N, vl); + if constexpr (RM > 2) acc_2 = __riscv_vle32_v_f32m4(C + 2 * N, vl); + if constexpr (RM > 3) acc_3 = __riscv_vle32_v_f32m4(C + 3 * N, vl); + if constexpr (RM > 4) acc_4 = __riscv_vle32_v_f32m4(C + 4 * N, vl); + if constexpr (RM > 5) acc_5 = __riscv_vle32_v_f32m4(C + 5 * N, vl); + if constexpr (RM > 6) acc_6 = __riscv_vle32_v_f32m4(C + 6 * N, vl); + + for (int kk = 0; kk < K; kk++) { + vfloat32m4_t b_0 = __riscv_vle32_v_f32m4(B + kk * N, vl); + + acc_0 = __riscv_vfmacc_vf_f32m4(acc_0, A[0 * K + kk], b_0, vl); + if constexpr (RM > 1) acc_1 = __riscv_vfmacc_vf_f32m4(acc_1, A[1 * K + kk], b_0, vl); + if constexpr (RM > 2) acc_2 = __riscv_vfmacc_vf_f32m4(acc_2, A[2 * K + kk], b_0, vl); + if constexpr (RM > 3) acc_3 = __riscv_vfmacc_vf_f32m4(acc_3, A[3 * K + kk], b_0, vl); + if constexpr (RM > 4) acc_4 = __riscv_vfmacc_vf_f32m4(acc_4, A[4 * K + kk], b_0, vl); + if constexpr (RM > 5) acc_5 = __riscv_vfmacc_vf_f32m4(acc_5, A[5 * K + kk], b_0, vl); + if constexpr (RM > 6) acc_6 = __riscv_vfmacc_vf_f32m4(acc_6, A[6 * K + kk], b_0, vl); + } + + __riscv_vse32_v_f32m4(C + 0 * N, acc_0, vl); + if constexpr (RM > 1) __riscv_vse32_v_f32m4(C + 1 * N, acc_1, vl); + if constexpr (RM > 2) __riscv_vse32_v_f32m4(C + 2 * N, acc_2, vl); + if constexpr (RM > 3) __riscv_vse32_v_f32m4(C + 3 * N, acc_3, vl); + if constexpr (RM > 4) __riscv_vse32_v_f32m4(C + 4 * N, acc_4, vl); + if constexpr (RM > 5) __riscv_vse32_v_f32m4(C + 5 * N, acc_5, vl); + if constexpr (RM > 6) __riscv_vse32_v_f32m4(C + 6 * N, acc_6, vl); +} + +template +static inline void rvv_simd_gemm_dispatch_tail( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int K, int N, int KN, int remaining_rows) +{ + if constexpr (RM > 0) { + if (remaining_rows == RM) { + int64_t jj = 0; + for (; jj + KN <= N; jj += KN) { + rvv_simd_gemm_ukernel(C + jj, A, B + jj, K, N, KN); + } + if (jj < N) { + rvv_simd_gemm_ukernel(C + jj, A, B + jj, K, N, N - jj); + } + } else { + rvv_simd_gemm_dispatch_tail(C, A, B, K, N, KN, remaining_rows); + } + } +} + +static constexpr int GEMM_RM = 7; + +// C[M x N] += A[M x K] * B[K x N] +static void simd_gemm( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int M, int K, int N) +{ + const int KN = (int)__riscv_vlenb(); + int64_t ii = 0; + for (; ii + GEMM_RM <= M; ii += GEMM_RM) { + int64_t jj = 0; + for (; jj + KN <= N; jj += KN) { + rvv_simd_gemm_ukernel(C + jj, A, B + jj, K, N, KN); + } + if (jj < N) { + rvv_simd_gemm_ukernel(C + jj, A, B + jj, K, N, N - jj); + } + A += GEMM_RM * K; + C += GEMM_RM * N; + } + + int remaining_rows = M - ii; + rvv_simd_gemm_dispatch_tail(C, A, B, K, N, KN, remaining_rows); +} #if defined(__GNUC__) && !defined(__clang__) #pragma GCC diagnostic pop From 94d6d0b743206b10a3074ea805c385b18fcd1498 Mon Sep 17 00:00:00 2001 From: rehan-10xengineer Date: Thu, 16 Apr 2026 13:15:15 +0500 Subject: [PATCH 146/249] ggml-cpu: add 128-bit RVV implementation for Quantization Vector Dot (llama/20633) * ggml-cpu: add 128-bit impls for i-quants, ternary quants * ggml-cpu: add 128-bit impls for iq2_xs, iq3_s, iq3_xxs, tq2_0 Co-authored-by: Rehan Qasim * ggml-cpu: refactor; add rvv checks --------- Co-authored-by: taimur-10x Co-authored-by: Rehan Qasim --- ggml/src/ggml-cpu/arch/riscv/quants.c | 972 ++++++++++++++++++++++++-- 1 file changed, 902 insertions(+), 70 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index d7e9ba46348..d3278d6489f 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -15,6 +15,12 @@ #include // for qsort #include // for GGML_ASSERT +#ifdef _MSC_VER +#define NOINLINE __declspec(noinline) +#else +#define NOINLINE __attribute__((__noinline__)) +#endif + #define GROUP_MAX_EPS 1e-15f #define GROUP_MAX_EPS_IQ3_XXS 1e-8f #define GROUP_MAX_EPS_IQ2_S 1e-8f @@ -117,7 +123,7 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in assert(k % QK_K == 0); size_t nb = k / QK_K; -#if defined(__riscv_v_intrinsic) +#if defined __riscv_v_intrinsic block_q8_K * y_blocks = (block_q8_K *)y; const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8(); @@ -2053,7 +2059,119 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16m1_t qh = __riscv_vle16_v_u16m1(x[i].qh, 8); + + // Calculate ls. + vuint16m1_t temp = __riscv_vsrl_vx_u16m1(qh, 12, 8); + temp = __riscv_vand_vx_u16m1(temp, 7, 8); + vint32m2_t ls = __riscv_vreinterpret_v_u32m2_i32m2(__riscv_vwmulu_vx_u32m2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32m2(ls, 1, 8); + + // Calculate delta. + vbool16_t mask = __riscv_vmseq_vx_u16m1_b16(__riscv_vand_vx_u16m1(qh, 0x8000, 8), 0, 8); + vint32m2_t delta_neg = __riscv_vmv_v_x_i32m2(-1, 8); + vint32m2_t delta_pos = __riscv_vmv_v_x_i32m2(1, 8); + vint32m2_t delta = __riscv_vmerge_vvm_i32m2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8m2_t qs = __riscv_vle8_v_u8m2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m4_t qh_shift = __riscv_vreinterpret_v_u64m4_u16m4(__riscv_vmv_v_x_u64m4(shift, 8)); + vuint16m4_t qh_gather_index = __riscv_vreinterpret_v_i16m4_u16m4( + __riscv_vdiv_vx_i16m4(__riscv_vreinterpret_v_u16m4_i16m4(__riscv_vid_v_u16m4(32)), 4, 32)); + vuint16m4_t qh_ext = __riscv_vlmul_ext_v_u16m2_u16m4(__riscv_vlmul_ext_v_u16m1_u16m2(qh)); + vuint16m4_t qh_index = __riscv_vrgather_vv_u16m4(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m4(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m4(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m4(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m4(qh_index, __riscv_vzext_vf2_u16m4(qs, 32), 32); + vuint16m4_t index = __riscv_vsll_vx_u16m4(qh_index, 3, 32); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-2 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 0); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[0], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 3-4 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 1); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[64], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 5-6 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 2); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[128], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 7-8 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 3); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[192], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32m2_t lsums = __riscv_vle32_v_i32m2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16m2_t bsums_0 = __riscv_vle16_v_i16m2(y[i].bsums, 16); + const vuint32m2_t bsums_i32 = __riscv_vreinterpret_v_u16m2_u32m2(__riscv_vreinterpret_v_i16m2_u16m2(bsums_0)); + const vint16m1_t bsums_i32_0 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(bsums_i32, 0, 8)); + const vint16m1_t bsums_i32_1 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(bsums_i32, 16, 8)); + const vint32m2_t bsums = __riscv_vwadd_vv_i32m2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32m2_t sumi_v = __riscv_vmul_vv_i32m2(ls, lsums, 8); + vint32m2_t sumi1_v = __riscv_vmul_vv_i32m2(__riscv_vmul_vv_i32m2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2153,6 +2271,9 @@ static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq1_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2166,7 +2287,174 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m4_t acc1 = __riscv_vmv_v_x_i32m4(0, 16); + vint32m4_t acc2 = __riscv_vmv_v_x_i32m4(0, 16); + + // We process 8 16-element sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/128; ib++) { + // Load qh for 8 sub-blocks. + const vuint8mf2_t qh_8 = __riscv_vle8_v_u8mf2(qh, 8); + const vuint16m1_t qh_16_lo = __riscv_vzext_vf2_u16m1(qh_8, 8); + const vuint16m1_t qh_16_hi = __riscv_vsll_vx_u16m1(qh_16_lo, 8, 8); + const vuint16m2_t qhb = __riscv_vzext_vf2_u16m2( + __riscv_vreinterpret_v_u16m1_u8m1(__riscv_vor_vv_u16m1(qh_16_lo, qh_16_hi, 8)), 16); + qh += 8; + + // Prepare grid indices. + const vuint16m2_t qsb = __riscv_vzext_vf2_u16m2(__riscv_vle8_v_u8m1(&qs[0], 16), 16); + const vuint16m2_t shift = __riscv_vreinterpret_v_u32m2_u16m2(__riscv_vmv_v_x_u32m2(0x00040008, 8)); + vuint16m2_t index = __riscv_vor_vv_u16m2(qsb, __riscv_vand_vx_u16m2(__riscv_vsll_vv_u16m2(qhb, shift, 16), 0x700, 16), 16); + index = __riscv_vsll_vx_u16m2(index, 3, 16); + qs += 16; + + // Prepare the deltas. + const vbool8_t mask = __riscv_vmsgtu_vx_u16m2_b8( + __riscv_vand_vv_u16m2(qhb, __riscv_vreinterpret_v_u32m2_u16m2(__riscv_vmv_v_x_u32m2(0x00800008, 8)), 16), 0, 16); + const vint64m8_t delta_pos = __riscv_vmv_v_x_i64m8(0x0101010101010101, 16); + const vint8m8_t delta = __riscv_vreinterpret_v_i64m8_i8m8( + __riscv_vmerge_vxm_i64m8(delta_pos, 0xffffffffffffffff, mask, 16)); + + // Sub-blocks 0-3 + { + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, __riscv_vget_v_u16m2_u16m1(index, 0), 8))); + + // Calculate the lsums. + // + // Sub-block 0, 1 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 0), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 0), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-block 2, 3 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 1), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 1), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + sc += 1; + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 4-7 + { + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, __riscv_vget_v_u16m2_u16m1(index, 1), 8))); + + // Calculate the lsums. + // + // Sub-block 4, 5 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 0), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 2), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-block 6, 7 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 1), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 3), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + sc += 1; + } + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(acc1, one, 16)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(acc2, one, 16)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2193,9 +2481,10 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16); vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16); - // We process 4 sub-blocks together. + // We process 8 16-element sub-blocks together. + #pragma GCC unroll 1 for (int ib = 0; ib < QK_K/128; ib++) { - // Load qh for 4 sub-blocks. + // Load qh for 8 sub-blocks. const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8); const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8); const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8); @@ -2203,6 +2492,8 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16); qh += 8; + __asm__ __volatile__("" ::: "memory"); + // Prepare grid indices. const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16); const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8)); @@ -2210,6 +2501,8 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t index = __riscv_vsll_vx_u16m1(index, 3, 16); qs += 16; + __asm__ __volatile__("" ::: "memory"); + // Load the grid. const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( __riscv_vluxei16_v_u64m4(iq1s_grid, index, 16))); @@ -2218,9 +2511,8 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16); const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16); - const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16); const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( - __riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16)); + __riscv_vmerge_vxm_i64m4(delta_pos, 0xffffffffffffffff, mask, 16)); // Load q8 for sub-blocks. const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); @@ -2261,6 +2553,8 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); + + __asm__ __volatile__("" ::: "memory"); } // Reduce and accumulate in `sumf`. @@ -2277,6 +2571,9 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq1_m_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2300,8 +2597,7 @@ static const uint8_t sign_bit_masks_arr[64] = { 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128 }; - -static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -2392,7 +2688,7 @@ static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t *s = 0.125f * sumf; } -static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -2513,7 +2809,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined(__riscv_v_intrinsic) +#if defined __riscv_v_intrinsic static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -2549,7 +2845,84 @@ static const int8_t keven_signs_q2xs[1024] = { 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, }; -static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xs_grid; + + float sumf = 0.0f; +#pragma GCC unroll 1 + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT qs = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + + int32_t sum_int = 0; + + // Loop over 4 subblocks of 64 elements + for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) { + + // Load indices. + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 8); + qs += 8; + + // Prepare offsets + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 8), 3, 8); + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 8), 3, 8); + + // load values and signs from the lookup tables + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 8); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 8); + vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 64); + asm volatile("" ::: "memory"); + vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 64); + q8 += 64; + + vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 64); + asm volatile("" ::: "memory"); + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 0), zero_vec, 16)); + + int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 1), zero_vec, 16)); + + int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 2), zero_vec, 16)); + + int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 3), zero_vec, 16)); + + const uint8_t scale_byte_1 = scales[0]; + const uint8_t scale_byte_2 = scales[1]; + scales += 2; + + sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1); + sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1); + sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1); + sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1); + } + + sumf += d * sum_int; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2628,6 +3001,9 @@ static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq2_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2641,7 +3017,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2732,7 +3108,7 @@ static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size *s = 0.125f * sumf; } -static void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2833,7 +3209,7 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const case 128: ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: + default: // 256 and above ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; } @@ -2843,7 +3219,102 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint32_t * grid32 = (const uint32_t *)iq3s_grid; + + vuint8mf2_t v_id_8 = __riscv_vid_v_u8mf2(8); + vuint8m2_t v_id_32 = __riscv_vid_v_u8m2(32); + + // Keeping these in a tight scope to hint they're only needed for the mask computation. + vuint8m2_t v_sign_gather_indices, v_sign_masks; + { + vuint8m2_t v_shifts = __riscv_vand_vx_u8m2(v_id_32, 7, 32); + vuint8m2_t v_one_32 = __riscv_vmv_v_x_u8m2(1, 32); + v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_id_32, 3, 32); + v_sign_masks = __riscv_vsll_vv_u8m2(v_one_32, v_shifts, 32); + } + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d); + const float combined_scale = d * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 8; ++ib) { + + // Grid lookup + vuint8m2_t v_grid_u8; + { + vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 8); + qs += 8; + + uint8_t qh_val = *qh++; + vuint8mf2_t v_qh_val = __riscv_vmv_v_x_u8mf2(qh_val, 8); + v_qh_val = __riscv_vsrl_vv_u8mf2(v_qh_val, v_id_8, 8); + v_qh_val = __riscv_vand_vx_u8mf2(v_qh_val, 1, 8); + + vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 8); + v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 8); + + vuint16m1_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qh_val, 8); + v_qh_u16 = __riscv_vsll_vx_u16m1(v_qh_u16, 10, 8); + + vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_u16, 8); + + vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2(grid32, v_grid_offsets, 8); + v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + } + __asm__ volatile ("" ::: "memory"); + + //Sign application and dot product + int32_t s_val; + { + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 4); + signs += 4; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 32); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 32); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + s_val = __riscv_vmv_x_s_i32m1_i32( + __riscv_vwredsum_vs_i16m4_i32m1(v_dot, v_zero, 32)); + } + __asm__ volatile ("" ::: "memory"); + { + uint8_t sc_byte = scales[ib >> 1]; + int sc_val = (ib & 1) ? (sc_byte >> 4) : (sc_byte & 0xF); + sc_val = sc_val * 2 + 1; + sum_block += (float)(s_val * sc_val); + } + } + sumf += sum_block * combined_scale; + } + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); @@ -2942,6 +3413,9 @@ static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2955,7 +3429,100 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + // constants for unpacking logic + const uint32_t shifts_val[8] = {0, 7, 14, 21, 0, 7, 14, 21}; + vuint32m2_t v_shifts = __riscv_vle32_v_u32m2(shifts_val, 8); + + const uint32_t gather_idx_val[8] = {0, 0, 0, 0, 1, 1, 1, 1}; + vuint32m2_t v_gather_idx = __riscv_vle32_v_u32m2(gather_idx_val, 8); + + uint32_t aux32[2]; + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float block_sum = 0.0f; + + // Process 64 weights per loop + for (int ib = 0; ib < QK_K / 64; ++ib) { + + // load of metadata via memcpy + memcpy(aux32, metadata, 2 * sizeof(uint32_t)); + metadata += 2 * sizeof(uint32_t); + + vuint8m1_t v_q3_idx_u8 = __riscv_vle8_v_u8m1(q3_indices, 16); + q3_indices += 16; + + vuint16m2_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m2(v_q3_idx_u8, 4, 16); + + vuint32m4_t v_q3_magnitudes_u32 = __riscv_vluxei16_v_u32m4(grid32, v_q3_idx_u16, 16); + + vint8m4_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m4_i8m4( + __riscv_vreinterpret_v_u32m4_u8m4(v_q3_magnitudes_u32)); + + vuint32m2_t v_aux = __riscv_vle32_v_u32m2(aux32, 2); + + vuint32m2_t v_aux_expanded = __riscv_vrgather_vv_u32m2(v_aux, v_gather_idx, 8); + + vuint32m2_t v_s_vals_raw = __riscv_vand_vx_u32m2( + __riscv_vsrl_vv_u32m2(v_aux_expanded, v_shifts, 8), 127, 8); + + vuint16m1_t sign_indices_byte_offset = __riscv_vsll_vx_u16m1( + __riscv_vncvt_x_x_w_u16m1(v_s_vals_raw, 8), 3, 8); + + vuint64m4_t v_s_vals_u64 = __riscv_vluxei16_v_u64m4(signs64, sign_indices_byte_offset, 8); + + vint8m4_t v_s_vals = __riscv_vreinterpret_v_u8m4_i8m4( + __riscv_vreinterpret_v_u64m4_u8m4(v_s_vals_u64)); + + vint8m4_t v_q3_signed = __riscv_vmul_vv_i8m4(v_q3_magnitudes, v_s_vals, 64); + asm volatile("" ::: "memory"); + vint8m4_t v_q8 = __riscv_vle8_v_i8m4(q8, 64); + q8 += 64; + + vint16m8_t v_dot = __riscv_vwmul_vv_i16m8(v_q8, v_q3_signed, 64); + + asm volatile("" ::: "memory"); + + vint16m4_t v_dot_1 = __riscv_vget_v_i16m8_i16m4(v_dot, 0); + vint16m4_t v_dot_2 = __riscv_vget_v_i16m8_i16m4(v_dot, 1); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + vint32m1_t v_sum_1 = __riscv_vwredsum_vs_i16m4_i32m1(v_dot_1, v_zero, 32); + vint32m1_t v_sum_2 = __riscv_vwredsum_vs_i16m4_i32m1(v_dot_2, v_zero, 32); + + int32_t sum1_i = __riscv_vmv_x_s_i32m1_i32(v_sum_1); + int32_t sum2_i = __riscv_vmv_x_s_i32m1_i32(v_sum_2); + + const float scale1_f = (float)(2 * (aux32[0] >> 28) + 1); + const float scale2_f = (float)(2 * (aux32[1] >> 28) + 1); + + block_sum += sum1_i * scale1_f + sum2_i * scale2_f; + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -3052,6 +3619,9 @@ static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq3_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -3065,7 +3635,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3095,12 +3665,14 @@ static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_ vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32); // Unpack the weight blocks. - vuint8m2_t iq4bits1; - iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 0, __riscv_vand_vx_u8m1(iq4_packed1, 0xf, 16)); - iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 1, __riscv_vsrl_vx_u8m1(iq4_packed1, 4, 16)); - vuint8m2_t iq4bits2; - iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 0, __riscv_vand_vx_u8m1(iq4_packed2, 0xf, 16)); - iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 1, __riscv_vsrl_vx_u8m1(iq4_packed2, 4, 16)); + vuint8m2_t iq4bits1 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(iq4_packed1, 0xf, 16), + __riscv_vsrl_vx_u8m1(iq4_packed1, 4, 16) + ); + vuint8m2_t iq4bits2 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(iq4_packed2, 0xf, 16), + __riscv_vsrl_vx_u8m1(iq4_packed2, 4, 16) + ); // Gather values from the lookup table. vint8m2_t iq4b1 = __riscv_vrgather_vv_i8m2(values, iq4bits1, 32); @@ -3118,7 +3690,7 @@ static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_ *s = sumf; } -static void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3182,7 +3754,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v case 128: ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: + default: // 256 and above ggml_vec_dot_iq4_nl_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc); break; } @@ -3192,7 +3764,73 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + float sumf = 0; + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + // We process 2 sub-blocks together. + int sumi1 = 0, sumi2 = 0; + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K / 64; ++ib) { + // Load the packed weights. + const vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 32); + iq4 += 32; + + // Unpack the weight blocks. + const vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 32); + const vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 32); + const vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + const vuint8m4_t iq4bits_reorder = __riscv_vcreate_v_u8m1_u8m4( + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 0), 16), + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 2), 16), + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 1), 16), + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 3), 16) + ); + const vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 64); + + // Multiply with activations. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 64); + q8 += 64; + const vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 64); + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + + const int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + h >>= 4; + + sumi1 += acc0 * ls1; + sumi2 += acc1 * ls2; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3207,16 +3845,15 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); float sumf = 0; - int acc[4]; // Indices for re-ordering IQ4 data. - uint64_t index[16] = { + uint16_t index[16] = { 0, 1, 8, 9, 2, 3, 10, 11, 4, 5,12, 13, 6, 7, 14, 15, }; - vuint64m4_t i_vec = __riscv_vle64_v_u64m4(index, 16); + vuint16m1_t i_vec = __riscv_vle16_v_u16m1(index, 16); for (int ibl = 0; ibl < nb; ++ibl) { const int8_t * q8 = y[ibl].qs; @@ -3225,30 +3862,33 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + #pragma GCC unroll 1 for (int ib = 0; ib < QK_K / 128; ++ib) { // Weights and activations. vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 64); - vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); iq4 += 64; - q8 += 128; // Unpack the weight blocks. vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 64); vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 64); - vuint8m4_t iq4bits; - iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 0, iq4bits_lo); - iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 1, iq4bits_hi); - vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgather_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 16)); + vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgatherei16_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 16)); vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 128); + __asm__ __volatile__("" ::: "memory"); + // Multiply with activations. + vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 128); + q8 += 128; + + __asm__ __volatile__("" ::: "memory"); // Reduce separately. - __riscv_vse32_v_i32m1(&acc[0],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); - __riscv_vse32_v_i32m1(&acc[1],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); - __riscv_vse32_v_i32m1(&acc[2],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); - __riscv_vse32_v_i32m1(&acc[3],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); int ls1 = ((x[ibl].scales_l[ib * 2 + 0] & 0xf) | ((h << 4) & 0x30)) - 32; int ls2 = ((x[ibl].scales_l[ib * 2 + 0] >> 4) | ((h << 2) & 0x30)) - 32; @@ -3256,10 +3896,12 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ int ls4 = ((x[ibl].scales_l[ib * 2 + 1] >> 4) | ((h >> 2) & 0x30)) - 32; h >>= 8; - sumi1 += acc[0] * ls1; - sumi2 += acc[1] * ls2; - sumi3 += acc[2] * ls3; - sumi4 += acc[3] * ls4; + sumi1 += acc0 * ls1; + sumi2 += acc1 * ls2; + sumi3 += acc2 * ls3; + sumi4 += acc3 * ls4; + + __asm__ __volatile__("" ::: "memory"); } sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2 + sumi3 + sumi4); @@ -3272,6 +3914,9 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq4_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -3285,7 +3930,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3301,8 +3946,107 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; for (int i = 0; i < nb; i++) { + const uint8_t * GGML_RESTRICT tq = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + // First loop. - vint32m4_t suml1; + vint16m4_t suml1; + { + const int vl = 32; + const vuint8m2_t tqb = __riscv_vle8_v_u8m2(tq, vl); + tq += 32; + + { + const vuint16m4_t tq0 = __riscv_vsrl_vx_u16m4(__riscv_vwmulu_vx_u16m4(tqb, 3, vl), 8, vl); + const vint16m4_t q80 = __riscv_vwcvt_x_x_v_i16m4(__riscv_vle8_v_i8m2(q8, vl), vl); + suml1 = __riscv_vmul_vv_i16m4(__riscv_vreinterpret_v_u16m4_i16m4(__riscv_vsub_vx_u16m4(tq0, 1, vl)), q80, vl); + q8 += 32; + } + + uint8_t pow3 = 3; + #pragma GCC unroll 1 + for (int t = 0; t < 4; t++) { + const vuint16m4_t tqn = __riscv_vsrl_vx_u16m4(__riscv_vwmulu_vx_u16m4(__riscv_vmul_vx_u8m2(tqb, pow3, vl), 3, vl), 8, vl); + const vint16m4_t q8n = __riscv_vwcvt_x_x_v_i16m4(__riscv_vle8_v_i8m2(q8, vl), vl); + suml1 = __riscv_vmacc_vv_i16m4(suml1, __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vsub_vx_u16m4(tqn, 1, vl)), q8n, vl); + pow3 *= 3; + q8 += 32; + } + } + + // Second loop. + vint16m2_t suml2; + { + const int vl = 16; + const vuint8m1_t tqb = __riscv_vle8_v_u8m1(tq, vl); + + { + const vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tqb, 3, vl), 8, vl); + const vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(q8, vl), vl); + suml2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); + q8 += 16; + } + + uint8_t pow3 = 3; + #pragma GCC unroll 1 + for (int t = 0; t < 4; t++) { + const vuint16m2_t tqn = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tqb, pow3, vl), 3, vl), 8, vl); + const vint16m2_t q8n = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(q8, vl), vl); + suml2 = __riscv_vmacc_vv_i16m2(suml2, __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tqn, 1, vl)), q8n, vl); + pow3 *= 3; + q8 += 16; + } + } + + // Third loop. + vint16m2_t suml3; + { + const int vl = 16; + + uint32_t qh; + memcpy(&qh, &x[i].qh[0], 4); + // Prevent fusion with vmv. + __asm__ __volatile__("" : "+r"(qh)); + const vuint8m1_t tqb = __riscv_vreinterpret_v_u32m1_u8m1(__riscv_vmv_v_x_u32m1(qh, vl / 4)); + + const vuint8m1_t p = __riscv_vle8_v_u8m1(pow, vl); + + const vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vv_u8m1(tqb, p, vl), 3, vl), 8, vl); + + const vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(q8, vl), vl); + + suml3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); + } + + vint16m2_t sumb = __riscv_vadd_vv_i16m2(__riscv_vget_v_i16m4_i16m2(suml1, 0), __riscv_vget_v_i16m4_i16m2(suml1, 1), 16); + sumb = __riscv_vadd_vv_i16m2(sumb, suml2, 16); + sumb = __riscv_vadd_vv_i16m2(sumb, suml3, 16); + + vint32m1_t sum = __riscv_vwredsum_vs_i16m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); + sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + for (int i = 0; i < nb; i++) { + // First loop. + vint16m2_t suml1; { const int vl = 32; vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl); @@ -3325,13 +4069,13 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl); vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl); - vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl); - vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl); - suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl); + vint16m2_t sumi0 = __riscv_vadd_vv_i16m2(sum0, sum1, vl); + vint16m2_t sumi1 = __riscv_vadd_vv_i16m2(sum2, sum3, vl); + suml1 = __riscv_vadd_vv_i16m2(sum4, __riscv_vadd_vv_i16m2(sumi0, sumi1, vl), vl); } // Second loop. - vint32m2_t suml2; + vint16m1_t suml2; { const int vl = 16; vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl); @@ -3354,13 +4098,13 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl); vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl); - vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl); - vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl); - suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl); + vint16m1_t sumi0 = __riscv_vadd_vv_i16m1(sum0, sum1, vl); + vint16m1_t sumi1 = __riscv_vadd_vv_i16m1(sum2, sum3, vl); + suml2 = __riscv_vadd_vv_i16m1(sum4, __riscv_vadd_vv_i16m1(sumi0, sumi1, vl), vl); } // Third loop. - vint32m2_t suml3; + vint16m1_t suml3; { const int vl = 16; @@ -3376,15 +4120,13 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl); - vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); - suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl); + suml3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); } - vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16); - sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16); - sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16); + vint16m1_t sumb = __riscv_vadd_vv_i16m1(__riscv_vget_v_i16m2_i16m1(suml1, 0), __riscv_vget_v_i16m2_i16m1(suml1, 1), 16); + sumb = __riscv_vadd_vv_i16m1(sumb, __riscv_vadd_vv_i16m1(suml2, suml3, 16), 16); - vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); + vint32m1_t sum = __riscv_vwredsum_vs_i16m1_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); } @@ -3395,6 +4137,9 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_tq1_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -3408,7 +4153,89 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl128(const int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x[0].qs); j += 32) { + const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32]; + const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32]; + const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32]; + const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32]; + const uint8_t* px = &x[i].qs[j]; + + size_t vl = __riscv_vsetvl_e16m4(32); + vint16m4_t vacc16 = __riscv_vmv_v_x_i16m4(0, vl); + + // Load Raw Packed elements + vl = __riscv_vsetvl_e8m2(32); + vuint8m2_t vx_u8 = __riscv_vle8_v_u8m2(px, vl); + + // Process bits 1:0 + { + // Unpack + vuint8m2_t t0 = __riscv_vand_vx_u8m2(vx_u8, 0x03, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t0), 1, vl); + vint8m2_t vy = __riscv_vle8_v_i8m2(py0, vl); + // Accumulate + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + // Process bits 3:2 + { + vuint8m2_t t1 = __riscv_vsrl_vx_u8m2(vx_u8, 2, vl); + t1 = __riscv_vand_vx_u8m2(t1, 0x03, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t1), 1, vl); + + vint8m2_t vy = __riscv_vle8_v_i8m2(py1, vl); + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + // Process bits 5:4 + { + vuint8m2_t t2 = __riscv_vsrl_vx_u8m2(vx_u8, 4, vl); + t2 = __riscv_vand_vx_u8m2(t2, 0x03, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t2), 1, vl); + + vint8m2_t vy = __riscv_vle8_v_i8m2(py2, vl); + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + // Process bits 7:6 + { + vuint8m2_t t3 = __riscv_vsrl_vx_u8m2(vx_u8, 6, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t3), 1, vl); + + vint8m2_t vy = __riscv_vle8_v_i8m2(py3, vl); + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + vl = __riscv_vsetvl_e16m4(32); + vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t vred32 = __riscv_vwredsum_vs_i16m4_i32m1(vacc16, vzero32, vl); + sumi += __riscv_vmv_x_s_i32m1_i32(vred32); + } + + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + sumf += (float)sumi * d; + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -3483,6 +4310,9 @@ static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_tq2_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -3496,7 +4326,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3526,12 +4356,14 @@ static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32); // Unpack the weight blocks. - vuint8m2_t mxbits1; - mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 0, __riscv_vand_vx_u8m1(mx_packed1, 0xf, 16)); - mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 1, __riscv_vsrl_vx_u8m1(mx_packed1, 4, 16)); - vuint8m2_t mxbits2; - mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 0, __riscv_vand_vx_u8m1(mx_packed2, 0xf, 16)); - mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 1, __riscv_vsrl_vx_u8m1(mx_packed2, 4, 16)); + vuint8m2_t mxbits1 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(mx_packed1, 0xf, 16), + __riscv_vsrl_vx_u8m1(mx_packed1, 4, 16) + ); + vuint8m2_t mxbits2 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(mx_packed2, 0xf, 16), + __riscv_vsrl_vx_u8m1(mx_packed2, 4, 16) + ); // Gather values from the lookup table. vint8m2_t mxb1 = __riscv_vrgather_vv_i8m2(values, mxbits1, 32); @@ -3549,7 +4381,7 @@ static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t *s = sumf; } -static void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3613,7 +4445,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo case 128: ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: + default: // 256 and above ggml_vec_dot_mxfp4_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc); break; } From 655c0750f5a027817fa3038f17232dd3bf717480 Mon Sep 17 00:00:00 2001 From: Kusha Gharahi <3326002+kushagharahi@users.noreply.github.com> Date: Thu, 16 Apr 2026 03:54:37 -0500 Subject: [PATCH 147/249] metal: Implement ROLL op (llama/21946) * nix: support unified apple-sdk * Impl roll op for Metal * Revert "nix: support unified apple-sdk" This reverts commit abfa473360471532c547de8b202c780507924d4b. * update ops.md * update op docs --- ggml/src/ggml-metal/ggml-metal-device.cpp | 17 +++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal-impl.h | 23 +++++++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 57 +++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 34 ++++++++++++++ 7 files changed, 134 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 8e0836c0beb..07d016d2227 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1819,6 +1819,23 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_roll(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_ROLL); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_roll_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_PAD); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index de43f819312..b423501358e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -152,6 +152,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_roll (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index effe666a691..27cb1683518 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1138,6 +1138,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_ARGSORT: case GGML_OP_TOP_K: case GGML_OP_ARANGE: + case GGML_OP_ROLL: return true; case GGML_OP_FLASH_ATTN_EXT: // for new head sizes, add checks here diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index e7433f2a658..379a8b33a14 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -1017,6 +1017,29 @@ typedef struct { int32_t p1; } ggml_metal_kargs_pad_reflect_1d; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t s0; + int32_t s1; + int32_t s2; + int32_t s3; +} ggml_metal_kargs_roll; + typedef struct { uint64_t nb1; int dim; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5b426be103f..e173527909a 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -410,6 +410,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx); } break; + case GGML_OP_ROLL: + { + n_fuse = ggml_metal_op_roll(ctx, idx); + } break; case GGML_OP_ARANGE: { n_fuse = ggml_metal_op_arange(ctx, idx); @@ -3945,6 +3949,59 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_roll(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t s0 = ggml_get_op_params_i32(op, 0); + const int32_t s1 = ggml_get_op_params_i32(op, 1); + const int32_t s2 = ggml_get_op_params_i32(op, 2); + const int32_t s3 = ggml_get_op_params_i32(op, 3); + + ggml_metal_kargs_roll args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.s0 =*/ s0, + /*.s1 =*/ s1, + /*.s2 =*/ s2, + /*.s3 =*/ s3 + }; + + auto pipeline = ggml_metal_library_get_pipeline_roll(lib, op); + + const int nth = std::min(1024, ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 50e3c5c77a1..36c61071b4f 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -81,6 +81,7 @@ int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_roll (ggml_metal_op_t ctx, int idx); int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 445a4deca83..9f38c9d2968 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5247,6 +5247,40 @@ kernel void kernel_upscale_bicubic_f32( } } +kernel void kernel_roll_f32( + constant ggml_metal_kargs_roll & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + device const float * src0_ptr = (device const float *) src0; + device float * dst_ptr = (device float *) dst; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + // apply shifts and wrap around + int64_t i00 = i0 - args.s0; + int64_t i01 = i1 - args.s1; + int64_t i02 = i2 - args.s2; + int64_t i03 = i3 - args.s3; + + if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; } + if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; } + if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; } + if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; } + + int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00; + int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0; + + dst_ptr[dst_idx] = src0_ptr[src_idx]; + } +} + kernel void kernel_pad_f32( constant ggml_metal_kargs_pad & args, device const char * src0, From 820438ae2c60aa9c5fd9a20edb085b414139440d Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 16 Apr 2026 17:21:28 +0800 Subject: [PATCH 148/249] ggml: add graph_reused (llama/21764) * ggml: add graph_reused * use versioning instead of reuse flag * increment version with atomic * use top bits for split numbering * add assert * move counter to ggml.c * set uid in split_graph only * fix windows * address further review comments * get next_uid rather than doing bit manipulation * rename + add comment about uid --- ggml/src/ggml-backend.cpp | 7 +++++++ ggml/src/ggml-cuda/common.cuh | 1 + ggml/src/ggml-cuda/ggml-cuda.cu | 9 +++++++++ ggml/src/ggml-impl.h | 6 ++++++ ggml/src/ggml.c | 12 ++++++++++++ 5 files changed, 35 insertions(+) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 1a555bf2a4d..d9f8aaec52f 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1030,6 +1030,8 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra GGML_ABORT("%s: failed to initialize context\n", __func__); } + graph->uid = ggml_graph_next_uid(); + // pass 1: assign backends to ops with pre-allocated inputs for (int i = 0; i < graph->n_leafs; i++) { struct ggml_tensor * leaf = graph->leafs[i]; @@ -1477,6 +1479,11 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra assert(graph_copy->size > graph_copy->n_leafs); graph_copy->leafs[graph_copy->n_leafs++] = leaf; } + + // set ids for all splits + for (int i = 0; i < sched->n_splits; ++i) { + sched->splits[i].graph.uid = ggml_graph_next_uid(); + } } static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ad30ecd8fa5..66ed02d2923 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1186,6 +1186,7 @@ struct ggml_cuda_graph { std::vector nodes; bool disable_due_to_gpu_arch = false; bool warmup_complete = false; + uint64_t uid = 0; struct node_properties { ggml_tensor node; void * node_src_data_ptrs[GGML_MAX_SRC]; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 790f53cead7..de579d2ed50 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3108,6 +3108,15 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx const void * graph_key = ggml_cuda_graph_get_key(cgraph); ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (cgraph->uid != 0 && + cgraph->uid == graph->uid) { + GGML_LOG_DEBUG("CUDA Graph id %zu reused\n", cgraph->uid); + GGML_ASSERT((int)graph->node_props.size() == cgraph->n_nodes); + return false; + } + + graph->uid = cgraph->uid; + // Check if the graph size has changed if ((int)graph->node_props.size() != cgraph->n_nodes) { res = true; diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 0639db362e7..62b76abbcec 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -30,6 +30,8 @@ extern "C" { void ggml_print_backtrace(void); +uint64_t ggml_graph_next_uid(void); + #ifndef MIN # define MIN(a, b) ((a) < (b) ? (a) : (b)) #endif @@ -338,6 +340,10 @@ struct ggml_cgraph { struct ggml_hash_set visited_hash_set; enum ggml_cgraph_eval_order order; + + // an optional identifier that can be utilized to recognize same graphs if two non-zero values match + // a value of 0 means it is not set and should be ignored + uint64_t uid; }; // returns a slice of cgraph with nodes [i0, i1) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 0142498d967..eda041f4518 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -53,6 +53,16 @@ #define UNUSED GGML_UNUSED +uint64_t ggml_graph_next_uid(void) { +#ifdef _MSC_VER + static volatile long long counter = 1; + return (uint64_t) _InterlockedIncrement64(&counter) - 1; +#else + static uint64_t counter = 1; + return __atomic_fetch_add(&counter, 1, __ATOMIC_RELAXED); +#endif +} + // Needed for ggml_fp32_to_bf16_row() #if defined(__AVX512BF16__) #if defined(_MSC_VER) @@ -7098,6 +7108,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz /*.use_counts =*/ use_counts_ptr, /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, + /*.uid =*/ 0, }; ggml_hash_set_reset(&cgraph->visited_hash_set); @@ -7125,6 +7136,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) /*.use_counts =*/ cgraph0->use_counts, /*.visited_hash_set =*/ cgraph0->visited_hash_set, /*.order =*/ cgraph0->order, + /*.uid =*/ 0 }; return cgraph; From 57a48a485084a7daa2b61b760968dc384a54b354 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Thu, 16 Apr 2026 12:08:33 -0700 Subject: [PATCH 149/249] opencl: add q5_K gemm and gemv kernels for Adreno (llama/21595) --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 326 ++++++++++++++++++ ggml/src/ggml-opencl/kernels/cvt.cl | 94 ++++- .../kernels/gemm_noshuffle_q5_k_f32.cl | 176 ++++++++++ .../kernels/gemv_noshuffle_q5_k_f32.cl | 326 ++++++++++++++++++ 5 files changed, 922 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 112c2afe821..772fc537494 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -121,6 +121,8 @@ set(GGML_OPENCL_KERNELS gemm_noshuffle_q4_k_f32 gemv_noshuffle_q6_k_f32 gemm_noshuffle_q6_k_f32 + gemv_noshuffle_q5_k_f32 + gemm_noshuffle_q5_k_f32 mul neg norm diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a581402300a..b27fbb13a3a 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -542,6 +542,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_K_noshuffle; cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; cl_kernel kernel_convert_block_q5_K, kernel_restore_block_q5_K; + cl_kernel kernel_convert_block_q5_K_noshuffle; + cl_kernel kernel_restore_block_q5_K_noshuffle; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; @@ -730,6 +732,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemm_noshuffle_q4_k_f32; cl_kernel kernel_gemv_noshuffle_q6_K_f32; cl_kernel kernel_gemm_noshuffle_q6_K_f32; + cl_kernel kernel_gemv_noshuffle_q5_k_f32; + cl_kernel kernel_gemm_noshuffle_q5_k_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { @@ -944,6 +948,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_K", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_K_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err)); @@ -2794,6 +2800,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q6_K_f32", &err), err)); GGML_LOG_CONT("."); } + + // gemv_noshuffle_q5_k_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_k_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_k_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_noshuffle_q5_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q5_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q5_k_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q5_k_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q5_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } #endif // GGML_OPENCL_USE_ADRENO_KERNELS GGML_LOG_CONT("\n"); } @@ -5354,7 +5399,17 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); CL_CHECK(err); + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q5_K_noshuffle; + } + #else cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; + #endif + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); @@ -5362,6 +5417,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_F0)); size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {64, 1, 1}; @@ -5378,6 +5435,21 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, extra->size_dm = size_dm; tensor->extra = extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + // Transpose q, d, dm as ushort, qh as uchar + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + transpose_2d_as_8b (backend_ctx, extra->qh, extra->qh, size_qh, K/8, M); + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/256, M); + transpose_2d_as_16b(backend_ctx, extra->dm, extra->dm, size_dm, K/256, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } if (tensor->type == GGML_TYPE_Q6_K) { @@ -5894,6 +5966,57 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_nbytes(tensor), NULL, &err); CL_CHECK(err); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + size_t size_q = extra->size_q; + size_t size_qh = extra->size_qh; + size_t size_d = extra->size_d; + size_t size_dm = extra->size_dm; + + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_qh; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_trans_dm; + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_trans_dm.allocate(backend_ctx->context, size_dm); + + // Reverse transpose q, qh, d, dm + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_8b (backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/8); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256); + transpose_2d_as_16b(backend_ctx, extra->dm, buf_trans_dm.buffer, size_dm, M, K/256); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_K_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_trans_dm.buffer)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_K; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); @@ -5901,6 +6024,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_F0)); size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {1, 1, 1}; @@ -10451,6 +10576,201 @@ static void ggml_cl_mul_mat_q6_K_f32_adreno(ggml_backend_t backend, const ggml_t #endif } +static void ggml_cl_mul_mat_q5_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_k = (ggml_tensor_extra_cl_q5_K *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + cl_uchar mask_d6 = 0x3F; + cl_uchar mask_d4 = 0x0F; + cl_uchar mask_hi2 = 0xC0; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem qh_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q (CL_R, CL_UNSIGNED_INT32): width = M*K/2/4 + img_fmt = {CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q5_k->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // image for qh (CL_R, CL_HALF_FLOAT): width = M*K/16 + img_fmt = {CL_R, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 16; + img_desc.buffer = extra0_q5_k->qh; + CL_CHECK((qh_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations (CL_RGBA, CL_FLOAT): width = K*N/4 + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q5_k_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qh_img)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_k->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_k->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_hi2)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(qh_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0) { + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float) / 2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N / 4; + if (height_B == 0) height_B = 1; + int width_B = K / 4; + int padded_height_B = (N + padding) / 4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = {1, 16}; + size_t global_work_size_t[2] = {(size_t)width_B, (size_t)padded_height_B}; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q5_k_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_k->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_k->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_k->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_k->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_uchar), &mask_hi2)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -10600,6 +10920,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // q5_K x fp32 + if (src0t == GGML_TYPE_Q5_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_K_f32_adreno(backend, src0, src1, dst); + return; + } + // q4_0 x fp32 if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { // TODO: remove duplicate definitions of image description + format -- move to top diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 1bd83d29b3d..39af32d282b 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -568,7 +568,9 @@ kernel void kernel_convert_block_q5_K( global uchar * dst_qh, global uchar * dst_s, global half * dst_d, - global half * dst_dm + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 ) { global struct block_q5_K * b = (global struct block_q5_K *) src0 + get_global_id(0); global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0); @@ -599,7 +601,9 @@ kernel void kernel_restore_block_q5_K( global uchar * src_s, global half * src_d, global half * src_dm, - global struct block_q5_K * dst + global struct block_q5_K * dst, + uchar mask_0F, + uchar mask_F0 ) { global struct block_q5_K * b = (global struct block_q5_K *) dst + get_global_id(0); global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0); @@ -622,6 +626,92 @@ kernel void kernel_restore_block_q5_K( } } +kernel void kernel_convert_block_q5_K_noshuffle( + global struct block_q5_K * src0, + global uchar * dst_q, + global uchar * dst_qh, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_K * b = (global struct block_q5_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2 * get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/8 * get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->qs[i*32 + 2*j]; + uchar x1 = b->qs[i*32 + 2*j + 1]; + q[i*32 + j] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + q[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + for (int l = 0; l < QK_K/8; ++l) { + uchar x0 = 0; + for (int i = 0; i < 8; ++i) { + x0 |= ((b->qh[(l%4)*8+i] >> (l/4)) & 0x01) << i; + } + qh[l] = x0; + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q5_K_noshuffle( + global uchar * src_q, + global uchar * src_qh, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q5_K * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_K * b = (global struct block_q5_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2 * get_global_id(0); + global uchar * qh = (global uchar *) src_qh + QK_K/8 * get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = q[i*32 + j]; + uchar hi = q[i*32 + j + 16]; + b->qs[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->qs[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } + + for (int g = 0; g < 4; ++g) { + for (int i = 0; i < 8; ++i) { + uchar x0 = 0; + for (int k = 0; k < 8; ++k) { + x0 |= ((qh[4*k+g] >> i) & 0x01) << k; + } + b->qh[g*8+i] = x0; + } + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q6_K // Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl new file mode 100644 index 00000000000..058c0f7edc6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl @@ -0,0 +1,176 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif +kernel void kernel_gemm_noshuffle_q5_k_f32( + global const ushort * src0_q, + global const uchar * src0_qh, + global const uchar * src0_s, + global const half * src0_d, + global const half * src0_dm, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + dst = (global float *)((global char *)dst + offsetd); + int n_4 = n >> 2; + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + int num_blocks_K = k / QK_K; + + global const ushort * weight_ptr = src0_q + gx_2; + global const uchar * qh_ptr = src0_qh + gx_2; + global const half * d_ptr = src0_d + gx_2; + global const half * dm_ptr = src0_dm + gx_2; + + for (int i = 0; i < k; i += 32) { + int sb_idx = i / QK_K; + int sub_idx = (i / 32) % 8; + + half4 d = vload4(0, d_ptr + sb_idx * m); + half4 dm = vload4(0, dm_ptr + sb_idx * m); + + global const uchar * sc0 = src0_s + (gx_2+0) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc1 = src0_s + (gx_2+1) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc2 = src0_s + (gx_2+2) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc3 = src0_s + (gx_2+3) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + + uchar sv0, mn0, sv1, mn1, sv2, mn2, sv3, mn3; + get_scale_min_k4(sub_idx, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc2, &sv2, &mn2, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc3, &sv3, &mn3, mask_d6, mask_d4, mask_hi2); + + half4 scale = convert_half4(convert_float4(d) * convert_float4((uchar4)(sv0, sv1, sv2, sv3))); + half4 mval = convert_half4(convert_float4(dm) * convert_float4((uchar4)(mn0, mn1, mn2, mn3))); + + for (int l = 0; l < 32; l += 4) { + int ki = i + l; + ushort4 bits4 = vload4(0, weight_ptr + (ki/4) * m); + uchar4 qh_bits = vload4(0, qh_ptr + (ki/8) * m); + int qh_shift = ki % 8; + + // j=0 + B.s0123 = read_imageh(src1, gy*2 + (ki+0) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+0) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0x000F) | (((qh_bits.s0 >> (qh_shift+0)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0x000F) | (((qh_bits.s1 >> (qh_shift+0)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0x000F) | (((qh_bits.s2 >> (qh_shift+0)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0x000F) | (((qh_bits.s3 >> (qh_shift+0)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (ki+1) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+1) * n_4); + dequantized_weights.s0 = (((bits4.s0 & 0x00F0) >> 4) | (((qh_bits.s0 >> (qh_shift+1)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = (((bits4.s1 & 0x00F0) >> 4) | (((qh_bits.s1 >> (qh_shift+1)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = (((bits4.s2 & 0x00F0) >> 4) | (((qh_bits.s2 >> (qh_shift+1)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = (((bits4.s3 & 0x00F0) >> 4) | (((qh_bits.s3 >> (qh_shift+1)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (ki+2) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+2) * n_4); + dequantized_weights.s0 = (((bits4.s0 & 0x0F00) >> 8) | (((qh_bits.s0 >> (qh_shift+2)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = (((bits4.s1 & 0x0F00) >> 8) | (((qh_bits.s1 >> (qh_shift+2)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = (((bits4.s2 & 0x0F00) >> 8) | (((qh_bits.s2 >> (qh_shift+2)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = (((bits4.s3 & 0x0F00) >> 8) | (((qh_bits.s3 >> (qh_shift+2)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (ki+3) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+3) * n_4); + dequantized_weights.s0 = (((bits4.s0 & 0xF000) >> 12) | (((qh_bits.s0 >> (qh_shift+3)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = (((bits4.s1 & 0xF000) >> 12) | (((qh_bits.s1 >> (qh_shift+3)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = (((bits4.s2 & 0xF000) >> 12) | (((qh_bits.s2 >> (qh_shift+3)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = (((bits4.s3 & 0xF000) >> 12) | (((qh_bits.s3 >> (qh_shift+3)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + } + + int idx = (gy<<3)*m + (gx<<2); + + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl new file mode 100644 index 00000000000..c40db166638 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl @@ -0,0 +1,326 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK_K 256 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s0 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s1 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s2 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s3 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s4 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s5 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s6 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s7 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s0 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s1 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s2 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s3 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s4 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s5 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s6 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s7 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q5_k_f32( + read_only image1d_buffer_t src0_q, + read_only image1d_buffer_t src0_qh, + global half2 * src0_d, + global half2 * src0_m, + global uchar * src0_s, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + uint LINE_STRIDE_A_QH = M / 2; + uint BLOCK_STRIDE_A_QH = NSUBGROUPS * M / 2; + uint scales_per_row = (K / QK_K) * 12; + + private uint4 regA; + private ushort4 regH; + private half2 regS; + private half2 regM; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / 32); k += NSUBGROUPS) { + uint sb = k / 8; + uint j = k % 8; + + half2 d = src0_d[gid + sb * LINE_STRIDE_A]; + half2 dm = src0_m[gid + sb * LINE_STRIDE_A]; + + global const uchar * sc0 = src0_s + 2 * gid * scales_per_row + sb * 12; + global const uchar * sc1 = src0_s + (2 * gid + 1) * scales_per_row + sb * 12; + + uchar sv0, mn0, sv1, mn1; + get_scale_min_k4(j, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(j, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + + regS = convert_half2(convert_float2(d) * convert_float2((uchar2)(sv0, sv1))); + regM = convert_half2(convert_float2(dm) * convert_float2((uchar2)(mn0, mn1))); + + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + regH.s0 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 0)).x); + regH.s1 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 1)).x); + regH.s2 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 2)).x); + regH.s3 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 3)).x); + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } +} From b25d5d050b53463ab415e4a6c7039c43bedee571 Mon Sep 17 00:00:00 2001 From: nullname Date: Fri, 17 Apr 2026 04:48:34 +0800 Subject: [PATCH 150/249] hexagon: optimize HMX matmul operations (llama/21071) * optimize hmx_mat_mul functions by calculating row and column tiles upfront * refactor core_dot_chunk_fp16 to use size_t for tile counts and improve readability * wip * set scale outside of loop * wip * refactor core_mma_chunk_fp16 and mat_mul_qk_0_d16a32 to use size_t for tile counts * wip * wip * refactor transfer_output_chunk_fp16_to_fp32 to use size_t for dimensions * refactor core_dot_chunk_fp16 to use size_t for tile row stride calculation * wip * refactor hmx_mat_mul functions to use hvx_vec_splat_f16 for column scales initialization * refactor hmx_mat_mul_permuted_w16a32_batched to streamline scale setting and locking * refactor core_dot_chunk_fp16 to improve tile stride calculations for output * refactor hmx_mat_mul functions to use Q6_V_vsplat_R for column scales initialization * fix compiling error * wip * optimize row and column tile indexing in core_mma_chunk_fp16 function * wip * Revert "wip" This reverts commit cde679eff79c4a28dd2d89d32f710015e09592b6. * Add size limit check for HAP_mmap in htp_iface_mmap and drop_mmap functions * wip --- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 96 +++++++++++----------- ggml/src/ggml-hexagon/htp/htp-ops.h | 2 + ggml/src/ggml-hexagon/htp/main.c | 31 ++++++- 3 files changed, 80 insertions(+), 49 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 485ec3f1aa9..dbca8220fab 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -648,9 +648,9 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( assert(n_cols % HMX_FP16_TILE_N_COLS == 0); assert(k_block % HMX_FP16_TILE_N_COLS == 0); - int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; - int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; - int n_tot_tiles = n_col_tiles * n_k_tiles; + size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; + size_t n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; + size_t n_tot_tiles = n_col_tiles * n_k_tiles; size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads); @@ -678,9 +678,8 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict __builtin_assume(n_dot_tiles > 0); Q6_bias_mxmem2_A((void *)scales); - for (int r = 0; r < n_row_tiles; ++r) { - for (int c = 0; c < n_col_tiles; ++c) { + for (size_t c = 0; c < n_col_tiles; ++c) { Q6_mxclracc_hf(); const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; @@ -738,25 +737,25 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; + const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; const HVX_Vector one = hvx_vec_splat_f16(1.0); - for (int r = 0; r < n_rows; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; - int r1 = r % HMX_FP16_TILE_N_ROWS; + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r0 = r / HMX_FP16_TILE_N_ROWS; + const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + float *output_row_base = dst + r * n; // global memory row base for row r (and r+1) #pragma unroll(4) - for (int c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) { - int c0 = c / HMX_FP16_TILE_N_COLS; - - const __fp16 *tile = vtcm_src + (r0 * n_col_tiles + c0) * HMX_FP16_TILE_N_ELMS; - - HVX_Vector v = ((const HVX_Vector *) tile)[r1 / 2]; + for (size_t c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) { + const size_t c0 = c / HMX_FP16_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); - volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (dst + (r * n + c + 0)); - volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (dst + (r * n + c + n)); // next row in global memory + volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row_base + c + 0); + volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (output_row_base + c + n); // next row in global memory *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); if (r + 1 < n_rows) { @@ -794,7 +793,7 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, assert(n_cols % HMX_FP16_TILE_N_COLS == 0); size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = 32; // must be multiple of HMX_FP16_TILE_N_ROWS (32) + size_t n_chunks_per_task = HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) output_transfer_task_state_t state; state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; @@ -926,7 +925,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 FARF(MEDIUM, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", __func__, params->m, params->k, params->n, group_size, params->ne13, @@ -944,12 +943,15 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + for (int b3 = 0; b3 < params->ne13; ++b3) { for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); // Pre-load activations for all heads in the group (once per m_chunk). // When the source is strided (permuted Q), use 2D DMA to gather @@ -987,10 +989,9 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); } - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); TIMER_START(weight_load); { @@ -1014,11 +1015,9 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu for (int g = 0; g < group_size; ++g) { TIMER_START(hmx_core); { - const __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - const int n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); - core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, - n_row_tiles, n_col_tiles, params->k / 32); + const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, + params->k / 32); } TIMER_STOP(hmx_core); @@ -1030,12 +1029,12 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu TIMER_STOP(output_store); } } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); } } } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + TIMER_STOP(total); #if defined(ENABLE_PROFILE_TIMERS) @@ -1103,7 +1102,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co return -1; } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 FARF(MEDIUM, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, @@ -1121,7 +1120,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { // transfer activation matrix chunk into VTCM - size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); TIMER_START(activation_load); { @@ -1159,7 +1159,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co } for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); TIMER_START(weight_load); { @@ -1184,8 +1185,6 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co TIMER_START(hmx_core); { - const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); } TIMER_STOP(hmx_core); @@ -1307,7 +1306,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds return -1; } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", __func__, m, k, n, weight_type, use_pipeline, @@ -1330,7 +1329,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds HAP_compute_res_hmx_lock(ctx->vtcm_rctx); for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { // transfer activation matrix chunk into VTCM - size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); TIMER_START(activation_load); { @@ -1348,7 +1348,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds } for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); TIMER_START(weight_load); { @@ -1373,8 +1374,6 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds TIMER_START(hmx_core); { - const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); } TIMER_STOP(hmx_core); @@ -1521,14 +1520,16 @@ void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __f Q6_bias_mxmem2_A((void *)col_scales); - for (int i = 0; i < n_row_tiles; ++i) { - for (int j = 0; j < n_col_tiles; ++j) { + const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t i = 0; i < n_row_tiles; ++i) { + const __fp16 *row_base = a + i * dot_tile_stride; + __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t j = 0; j < n_col_tiles; ++j) { Q6_mxclracc_hf(); - const __fp16 *row_tiles = a + i * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - const __fp16 *col_tiles = b + j * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - - __fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS; + const __fp16 *col_tiles = b + j * dot_tile_stride; + const __fp16 *row_tiles = row_base; + __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; if (!zero_init) { Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); @@ -1697,7 +1698,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict v = Q6_V_vror_VR(v, VLEN - 8); } } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 TIMER_DEFINE(fetch); TIMER_DEFINE(act_load); @@ -1715,7 +1716,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { - size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); + const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); TIMER_START(fetch); // fetch activation block into VTCM @@ -1731,13 +1732,13 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict } // fetch weight block into VTCM (x4x2 sub-block: quants + scales) + const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); { qweight_fetch_task_state_t s; const int blk_start = kk / QK_Q4_0x4x2; const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); - const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; @@ -1777,7 +1778,6 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict dma_queue_pop(ctx->dma[0]); // vtcm_scratch0 is used to store the qweight chunk // worker_pool_run_func already returned, so fetch is done - const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, n_blk_sz, k_blk_sz, sub_row_stride, weight_type); } diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index fa84b674cd2..79b5ecd2270 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -98,6 +98,8 @@ enum htp_op_code { #define HTP_OP_MAX_VMEM (3221225472u) #endif +#define HTP_MMAP_MAX_VMEM (2147483648u) + enum htp_tensor_flags { HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights) HTP_TENSOR_FLUSHED = (1U << 1) // Tensor buffer has been flushed (set by the NPU) diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index d71c97ed292..5091623a653 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -118,7 +118,11 @@ AEEResult htp_iface_close(remote_handle64 handle) { // release the mmaps (if any) for (uint32_t i=0; immap[i].size) { +#if __HVX_ARCH__ > 73 HAP_munmap2((void *) ctx->mmap[i].base, ctx->mmap[i].size); +#else + HAP_munmap((void *) ctx->mmap[i].base, ctx->mmap[i].size); +#endif ctx->mmap[i].size = 0; ctx->mmap[i].base = NULL; ctx->mmap[i].fd = -1; @@ -173,8 +177,16 @@ AEEResult htp_iface_mmap(remote_handle64 handle, int fd, uint32_t size, uint32_t struct htp_mmap *m = &ctx->mmap[i]; if (!m->size) { FARF(HIGH, "mmap : fd %u size %u pinned %u", fd, size, pinned); - +#if __HVX_ARCH__ > 73 void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0); +#else + if (size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB + FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) size); + abort(); // can't do much else at this point + } + + void *va = HAP_mmap(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0); +#endif if (va == (void*)-1) { FARF(ERROR, "mmap failed : va %p fd %u size %u", va, fd, (uint32_t) size); return AEE_EFAILED; @@ -202,7 +214,11 @@ AEEResult htp_iface_munmap(remote_handle64 handle, int fd) { struct htp_mmap *m = &ctx->mmap[i]; if (fd < 0 || m->fd == fd) { FARF(HIGH, "unmmap : base %p fd %u size %u", (void*) m->base, m->fd, (uint32_t) m->size); +#if __HVX_ARCH__ > 73 HAP_munmap2((void *) m->base, m->size); +#else + HAP_munmap((void *) m->base, m->size); +#endif m->size = 0; m->base = NULL; m->fd = -1; @@ -526,7 +542,11 @@ static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) { if (m->size && !m->pinned) { FARF(HIGH, "unmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned); +#if __HVX_ARCH__ > 73 HAP_munmap2((void *) m->base, m->size); +#else + HAP_munmap((void *) m->base, m->size); +#endif m->size = 0; m->base = 0; m->fd = -1; @@ -540,7 +560,16 @@ static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) { for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) { struct htp_mmap *m = &ctx->mmap[i]; if (!m->size) { +#if __HVX_ARCH__ > 73 void *va = HAP_mmap2(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0); +#else + if (b->size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB + FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) b->size); + abort(); // can't do much else at this point + } + + void *va = HAP_mmap(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0); +#endif if (va == (void*)-1) { FARF(ERROR, "mmap failed : va %p fd %u size %u", va, b->fd, (uint32_t) b->size); abort(); // can't do much else at this point From 77c0630ce64e63a63f16e30d9982608b5f6474fa Mon Sep 17 00:00:00 2001 From: lhez Date: Thu, 16 Apr 2026 22:28:33 -0700 Subject: [PATCH 151/249] opencl: refactor q8_0 set_tensor and mul_mat host side dispatch for Adreno (llama/21938) * opencl: refactor q8_0 gemm/gemv Adreno dispatch * opencl: refactor q8_0 set_tensor * opencl: fix whitespace --- ggml/src/ggml-opencl/ggml-opencl.cpp | 361 ++++++++------------------- 1 file changed, 99 insertions(+), 262 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index b27fbb13a3a..8bc7ae65a6d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -5116,115 +5116,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[3] == 1); - // Transpose weights - size_t q_size_bytes = K * M / 4 * sizeof(float); - cl_buffer_region region; - region.origin = 0; - region.size = q_size_bytes; - cl_mem qT_d = clCreateSubBuffer( - backend_ctx->prealloc_quant_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - CL_CHECK(err); - - cl_mem q_d_image1D; - cl_mem qT_d_image1D; - - cl_image_format img_fmt_1d; - cl_image_desc img_desc_1d; - - img_fmt_1d = { CL_RGBA, CL_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4 / 4; - img_desc_1d.buffer = extra->q; - q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - img_fmt_1d = { CL_RGBA, CL_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4 / 4; - img_desc_1d.buffer = qT_d; - qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - int height_q = M / 4; - int width_q = K / 4 / 4; - kernel = backend_ctx->kernel_transpose_32; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q)); - - size_t local_size_q[3] = {4, 16, 1}; - size_t global_size_q[3] = {static_cast(width_q), static_cast(height_q), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - - // Transpose scales - size_t d_size_bytes = M * (K / 32) * 2; - region.origin = 0; - region.size = d_size_bytes; - cl_mem dT_d = clCreateSubBuffer( - backend_ctx->prealloc_scales_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - CL_CHECK(err); - - cl_mem d_d_image1D; - cl_mem dT_d_image1D; - - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_fmt_1d = { CL_R, CL_HALF_FLOAT }; - img_desc_1d.image_width = M * K / 32; - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.buffer = extra->d; - d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32 / 4; - img_desc_1d.buffer = dT_d; - dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - int height_s = M / 4; - int width_s = K / 32; - - kernel = backend_ctx->kernel_transpose_16_4x1; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s)); - - size_t local_size_s[3] = {4, 16, 1}; - size_t global_size_s[3] = {static_cast(width_s), static_cast(height_s), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - - // copy transposed buffer contents to original buffers - CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - - CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - - CL_CHECK(clReleaseMemObject(qT_d)); - CL_CHECK(clReleaseMemObject(dT_d)); - - CL_CHECK(clReleaseMemObject(q_d_image1D)); - CL_CHECK(clReleaseMemObject(d_d_image1D)); - CL_CHECK(clReleaseMemObject(qT_d_image1D)); - CL_CHECK(clReleaseMemObject(dT_d_image1D)); + transpose_2d_as_32b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); } // end transpose #endif // GGML_OPENCL_USE_ADRENO_KERNELS @@ -9956,19 +9849,18 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const enum ggml_type src0t = src0->type; - const enum ggml_type src1t = src1->type; - - GGML_ASSERT(src0t == GGML_TYPE_Q8_0); - GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_Q8_0); + GGML_ASSERT(src1->type == GGML_TYPE_F32); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + GGML_ASSERT(src1->view_offs == 0); GGML_ASSERT(dst->view_offs == 0); @@ -9989,148 +9881,112 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t cl_context context = backend_ctx->context; cl_kernel kernel; - // init CL objects - cl_int status; - cl_image_format img_fmt_1d; - cl_image_desc img_desc_1d; + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; cl_buffer_region region; - cl_mem A_image1d; - cl_mem B_image1d; - cl_mem B_sub_buffer; - cl_mem S_image1d; - // for B transpose - cl_mem B_image1d_trans = nullptr; - cl_mem B_d = nullptr; - - cl_mem D_image1d; - cl_mem D_sub_buffer; int M = ne01; int N = ne1; int K = ne00; - // create an image for A - img_fmt_1d = { CL_R, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4; // Divide by 4 for char -> float - img_desc_1d.buffer = extra0_q8_0->q; - A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); - - // create an image for Scale - img_fmt_1d = { CL_R, CL_HALF_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32; // Block size is 32 - img_desc_1d.buffer = extra0_q8_0->d; - S_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); - - // create a sub_buffer for B - region.origin = (extra1->offset); // + src1->view_offs); - region.size = K * N * sizeof(float); - B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - - // create an image for B from sub_buffer: RGBA (OCL) - img_fmt_1d = {CL_RGBA, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = K * N / 4; - img_desc_1d.buffer = B_sub_buffer; - B_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; - // Create subbuffer and image1d_buffer for dst - region.origin = (extrad->offset); // + dst->view_offs; - region.size = M * N * sizeof(float); - D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 4; + img_desc.buffer = extra0_q8_0->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - img_fmt_1d = {CL_R, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * N; - img_desc_1d.buffer = D_sub_buffer; - D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); + // create a sub_buffer for B + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); - size_t local_work_size[3] = {1, 1, 1}; - size_t global_work_size[3] = {1, 1, 1}; + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - if (N == 1) { kernel = backend_ctx->CL_mul_mat_vec_q8_0_f32; int r2 = 1; int r3 = 1; - cl_uint k_arg = 0; - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extra0_q8_0->d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extra1->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extrad->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r2)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r3)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &extra1->offset)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &extrad->offset)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); size_t wavesize = backend_ctx->adreno_wave_size; - local_work_size[0] = wavesize; - local_work_size[1] = 4; // reduce factor - local_work_size[2] = 1; + size_t local_work_size[] = { wavesize, 4, 1 }; + size_t global_work_size[] = { CEIL_DIV(M, wavesize)*wavesize, 4, 1 }; - global_work_size[0] = ((M + wavesize - 1) / wavesize) * wavesize; - global_work_size[1] = 4; // reduce factor - global_work_size[2] = 1; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); } else { - cl_ulong offsetd = extrad->offset + dst->view_offs; - int padding; + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; - //how many extra elements beyond multiple of 8 - int extra_elements = N % 8; + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - //how much padding to add - padding = 0; + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; if (extra_elements > 0){ padding = 8 - extra_elements; } - // Specify the starting offset (in bytes) + // subbuffer for transposed activations region.origin = 0; - // Specify the size of the sub-buffer (divide by 2 for FP16) region.size = K * (N + padding) * sizeof(float)/2; backend_ctx->prealloc_act_trans.allocate(context, region.size); - B_d = clCreateSubBuffer( - backend_ctx->prealloc_act_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &status); - CL_CHECK(status); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); - cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16) - cl_image_desc image_desc_B_d_output = { - CL_MEM_OBJECT_IMAGE1D_BUFFER, - static_cast(K * (N + padding)/4), - 0, 0, 0, 0, 0, 0, 0, { B_d } - }; - B_image1d_trans = clCreateImage( - context, - 0, - &image_format_B_d_output, - &image_desc_B_d_output, - NULL, - &status); - CL_CHECK(status); + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + // transpose activations int height_B = N/4; if (height_B == 0) { height_B = 1; @@ -10139,58 +9995,39 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t int padded_height_B = (N + padding)/4; kernel = backend_ctx->kernel_transpose_32_16; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_image1d)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d_trans)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); - size_t local_size_t[2] = { 1, 16 }; - size_t global_size_t[2] = { - static_cast(width_B), - static_cast(padded_height_B) - }; - - backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst); + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + // gemm kernel = backend_ctx->kernel_mul_mm_q8_0_f32_8x4; - - int N_with_padding = N + padding; + int padded_N = N + padding; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &K)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &M)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &N_with_padding)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &padded_N)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &N)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd)); - global_work_size[0] = (size_t)(N + 7) / 8; - global_work_size[1] = (size_t)(M + 3) / 4; - global_work_size[2] = 1; - - local_work_size[0] = 2; - local_work_size[1] = 128; - local_work_size[2] = 1; - } + size_t global_work_size[] = { (size_t)CEIL_DIV(N, 8), (size_t)CEIL_DIV(M, 4), 1 }; + size_t local_work_size[] = { 2, 128, 1 }; - // enqueue kernel with profiling - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - // deallocate sub buffers and images - CL_CHECK(clReleaseMemObject(A_image1d)); - CL_CHECK(clReleaseMemObject(B_sub_buffer)); - CL_CHECK(clReleaseMemObject(B_image1d)); - CL_CHECK(clReleaseMemObject(S_image1d)); - CL_CHECK(clReleaseMemObject(D_sub_buffer)); - CL_CHECK(clReleaseMemObject(D_image1d)); - if (B_image1d_trans) { - CL_CHECK(clReleaseMemObject(B_image1d_trans)); - } - if (B_d) { - CL_CHECK(clReleaseMemObject(B_d)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); } #else GGML_UNUSED(backend); From 918e0ad20954beaf4a57675749e0a54f12f4233b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 17 Apr 2026 23:24:21 +0800 Subject: [PATCH 152/249] CUDA: use LRU based eviction for cuda graphs (llama/21611) * CUDA: use a ring-buffer for cuda graphs * bump limit to 128 * use LRU eviction * better naming * do periodic clean-up --- ggml/src/ggml-cuda/common.cuh | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 66ed02d2923..ddf50baf495 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1187,6 +1187,7 @@ struct ggml_cuda_graph { bool disable_due_to_gpu_arch = false; bool warmup_complete = false; uint64_t uid = 0; + int64_t last_used_time = 0; struct node_properties { ggml_tensor node; void * node_src_data_ptrs[GGML_MAX_SRC]; @@ -1368,12 +1369,28 @@ struct ggml_backend_cuda_context { // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe) std::unordered_map> cuda_graphs; + int64_t last_graph_eviction_sweep = 0; + ggml_cuda_graph * cuda_graph(const void * first_node_ptr) { + const int64_t time_now = ggml_time_us(); + + // sweep every 5s, evicting cuda graphs unused for >=10s + if (time_now - last_graph_eviction_sweep >= 5'000'000) { + last_graph_eviction_sweep = time_now; + for (auto it = cuda_graphs.begin(); it != cuda_graphs.end(); ) { + if (time_now - it->second->last_used_time >= 10'000'000) { + it = cuda_graphs.erase(it); + } else { + ++it; + } + } + } + auto it = cuda_graphs.find(first_node_ptr); if (it == cuda_graphs.end()) { - cuda_graphs[first_node_ptr] = std::make_unique(); - return cuda_graphs[first_node_ptr].get(); + it = cuda_graphs.emplace(first_node_ptr, std::make_unique()).first; } + it->second->last_used_time = time_now; return it->second.get(); } From cbbe935765b0a4d5f301ff8e8f4636f3b6e94c98 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 17 Apr 2026 09:17:11 -0700 Subject: [PATCH 153/249] ggml-webgpu: fix compiler warnings and refactor FlashAttention encoding (llama/21052) * Update workflows to remove dependence on llvmpipe * Try setting Dawn_DIR * remove c++20 initializers * Move to proper guid * Try avoiding segfaults on vulkan backend process exit * Remove compiler warnings on parameter casting * Fix soft_max and update reg_tile accumulation to f32 for better precision * Refactor flash_attn a bit * remove c++20 initializers and format * Increase div precision for NVIDIA * revert div precision and comment out ggml-ci node for now * Formatting * Try debugging on a failing CI node * Revert "Try debugging on a failing CI node" This reverts commit 1971e33cba919915e12bcfd5828abfbd54ca942e. --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 585 ++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 1498 +++++++---------- .../wgsl-shaders/flash_attn_vec_blk.wgsl | 12 +- 3 files changed, 918 insertions(+), 1177 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3de6258c74d..7d9a4403fab 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -390,12 +390,11 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_mask; bool has_sinks; bool uses_logit_softcap; - bool use_vec; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec; + uses_logit_softcap == other.uses_logit_softcap; } }; @@ -409,47 +408,37 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.has_mask); ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); - ggml_webgpu_hash_combine(seed, key.use_vec); return seed; } }; -struct ggml_webgpu_flash_attn_shader_lib_context { - ggml_webgpu_flash_attn_pipeline_key key; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; - size_t wg_mem_limit_bytes; - uint32_t max_subgroup_size; +struct ggml_webgpu_flash_attn_decisions { + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; }; -struct ggml_webgpu_flash_attn_shader_decisions { - uint32_t q_tile = 0; +struct ggml_webgpu_flash_attn_vec_decisions { uint32_t kv_tile = 0; uint32_t wg_size = 0; }; -inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { - // Keep conservative defaults unless this is the f16 vec-split shape family. - if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) { - return 1u; - } - - // Head-dim specializations used by the tuned vec f16 path. - switch (key.head_dim_qk) { - case 64: - return 2u; - case 96: - return 4u; - case 128: - return 1u; - case 192: - return 2u; - case 576: - return 2u; - default: - return 1u; - } +inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( + const ggml_webgpu_shader_lib_context & context) { + const bool has_mask = context.src3 != nullptr; + const bool has_sinks = context.src4 != nullptr; + const bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && + (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + + ggml_webgpu_flash_attn_pipeline_key key = {}; + key.kv_type = context.src1->type; + key.head_dim_qk = (uint32_t) context.src0->ne[0]; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.kv_direct = kv_direct; + key.has_mask = has_mask; + key.has_sinks = has_sinks; + key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; + return key; } struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { @@ -471,79 +460,20 @@ inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lh return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size; } -struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context { - ggml_webgpu_flash_attn_vec_reduce_pipeline_key key; - uint32_t max_wg_size; -}; - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { - std::vector defines; - std::string variant = "flash_attn_vec_reduce"; - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - variant += std::string("_wg") + std::to_string(context.max_wg_size); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - return result; -} - struct ggml_webgpu_flash_attn_blk_pipeline_key { - uint32_t q_tile; uint32_t kv_tile; - bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { - return q_tile == other.q_tile && kv_tile == other.kv_tile; - } + bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { return kv_tile == other.kv_tile; } }; struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const { size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.q_tile); ggml_webgpu_hash_combine(seed, key.kv_tile); return seed; } }; -struct ggml_webgpu_flash_attn_blk_shader_lib_context { - ggml_webgpu_flash_attn_blk_pipeline_key key; - uint32_t max_wg_size; -}; - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { - std::vector defines; - std::string variant = "flash_attn_vec_blk"; - - defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile)); - variant += std::string("_qt") + std::to_string(context.key.q_tile); - - defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile)); - variant += std::string("_kvt") + std::to_string(context.key.kv_tile); - - uint32_t wg_size = 1; - while ((wg_size << 1) <= context.max_wg_size) { - wg_size <<= 1; - } - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - variant += std::string("_wg") + std::to_string(wg_size); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - return result; -} - // This is exposed because it's necessary in supports_op inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, @@ -568,6 +498,41 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } +inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context, + const ggml_webgpu_flash_attn_pipeline_key & key) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!key.kv_direct) { + bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); + } + if (key.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; +} + +inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); + const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + uint32_t kv_tile = std::max(context.sg_mat_n, std::min(32u, min_kv_tile)); + kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; + + if (key.kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= context.sg_mat_n; + } + } + + return kv_tile; +} + /** Matrix Multiplication **/ struct ggml_webgpu_legacy_mul_mat_pipeline_key { @@ -802,6 +767,8 @@ class ggml_webgpu_shader_lib { repeat_pipelines; // type std::unordered_map flash_attn_pipelines; + std::unordered_map + flash_attn_vec_pipelines; std::unordered_map @@ -849,10 +816,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_row_norm_pipeline_key key = { - .op = context.dst->op, - .inplace = context.inplace, - }; + ggml_webgpu_row_norm_pipeline_key key = {}; + key.op = context.dst->op; + key.inplace = context.inplace; auto it = row_norm_pipelines.find(key); if (it != row_norm_pipelines.end()) { @@ -908,9 +874,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type, - .vec4 = context.src0->ne[0] % 4 == 0, - .i64_idx = context.src1->type == GGML_TYPE_I64 }; + ggml_webgpu_set_rows_pipeline_key key = {}; + key.dst_type = context.dst->type; + key.vec4 = context.src0->ne[0] % 4 == 0; + key.i64_idx = context.src1->type == GGML_TYPE_I64; auto it = set_rows_pipelines.find(key); if (it != set_rows_pipelines.end()) { @@ -955,7 +922,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace }; + ggml_webgpu_set_pipeline_key key = {}; + key.type = context.dst->type; + key.inplace = context.inplace; auto it = set_pipelines.find(key); if (it != set_pipelines.end()) { @@ -1062,10 +1031,9 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0; - ggml_webgpu_get_rows_pipeline_key key = { - .src_type = context.src0->type, - .vectorized = (int) vectorized, - }; + ggml_webgpu_get_rows_pipeline_key key = {}; + key.src_type = context.src0->type; + key.vectorized = (int) vectorized; auto it = get_rows_pipelines.find(key); if (it != get_rows_pipelines.end()) { @@ -1115,8 +1083,7 @@ class ggml_webgpu_shader_lib { std::string type_upper = type_str; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - switch (key.src_type) - { + switch (key.src_type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: @@ -1136,9 +1103,9 @@ class ggml_webgpu_shader_lib { break; } default: - { - defines.push_back(std::string("SRC_TYPE=") + type_str); - } + { + defines.push_back(std::string("SRC_TYPE=") + type_str); + } } defines.push_back("BYTE_HELPERS"); @@ -1181,7 +1148,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace }; + ggml_webgpu_scale_pipeline_key key = {}; + key.inplace = context.inplace; auto it = scale_pipelines.find(key); if (it != scale_pipelines.end()) { @@ -1208,11 +1176,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_solve_tri_pipeline_key key = { - .type = context.dst->type, - .n = (int) context.src0->ne[0], - .k = (int) context.src1->ne[0], - }; + ggml_webgpu_solve_tri_pipeline_key key = {}; + key.type = context.dst->type; + key.n = (int) context.src0->ne[0]; + key.k = (int) context.src1->ne[0]; auto it = solve_tri_pipelines.find(key); if (it != solve_tri_pipelines.end()) { @@ -1250,10 +1217,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_ssm_conv_pipeline_key key = { - .type = context.dst->type, - .vectorized = context.src1->ne[0] == 4, - }; + ggml_webgpu_ssm_conv_pipeline_key key = {}; + key.type = context.dst->type; + key.vectorized = context.src1->ne[0] == 4; auto it = ssm_conv_pipelines.find(key); if (it != ssm_conv_pipelines.end()) { @@ -1293,11 +1259,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_gated_delta_net_pipeline_key key = { - .type = context.dst->type, - .s_v = (int) context.src2->ne[0], - .kda = context.src3->ne[0] == context.src2->ne[0], - }; + ggml_webgpu_gated_delta_net_pipeline_key key = {}; + key.type = context.dst->type; + key.s_v = (int) context.src2->ne[0]; + key.kda = context.src3->ne[0] == context.src2->ne[0]; auto it = gated_delta_net_pipelines.find(key); if (it != gated_delta_net_pipelines.end()) { @@ -1330,7 +1295,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; + ggml_webgpu_pad_pipeline_key key = {}; + key.circular = ggml_get_op_params_i32(context.dst, 8) != 0; auto it = pad_pipelines.find(key); if (it != pad_pipelines.end()) { @@ -1357,15 +1323,13 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_vec_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float - .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0, - }; + ggml_webgpu_mul_mat_vec_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -1451,15 +1415,14 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0, - .use_subgroup_matrix = context.supports_subgroup_matrix - }; + ggml_webgpu_mul_mat_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { @@ -1578,8 +1541,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type, - .src1_type = context.src1->type }; + ggml_webgpu_legacy_mul_mat_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; auto it = mul_mat_legacy_pipelines.find(key); if (it != mul_mat_legacy_pipelines.end()) { @@ -1621,8 +1585,7 @@ class ggml_webgpu_shader_lib { std::string type_upper = src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - switch (context.src0->type) - { + switch (context.src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: @@ -1642,9 +1605,9 @@ class ggml_webgpu_shader_lib { break; } default: - { - defines.push_back(std::string("SRC0_TYPE=") + src0_name); - } + { + defines.push_back(std::string("SRC0_TYPE=") + src0_name); + } } defines.push_back("BYTE_HELPERS"); @@ -1689,10 +1652,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_id_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - }; + ggml_webgpu_mul_mat_id_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -1782,13 +1744,12 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool is_unary = context.dst->op == GGML_OP_UNARY; const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; - ggml_webgpu_unary_pipeline_key key = { - .type = context.dst->type, - .op = op, - .is_unary = is_unary, - .inplace = context.inplace, - .ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0), - }; + ggml_webgpu_unary_pipeline_key key = {}; + key.type = context.dst->type; + key.op = op; + key.is_unary = is_unary; + key.inplace = context.inplace; + key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0); auto it = unary_pipelines.find(key); if (it != unary_pipelines.end()) { @@ -1853,13 +1814,12 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_binary_pipeline_key key = { - .type = context.dst->type, - .op = context.dst->op, - .inplace = context.inplace, - .overlap = context.overlap, - .src_overlap = context.src_overlap, - }; + ggml_webgpu_binary_pipeline_key key = {}; + key.type = context.dst->type; + key.op = context.dst->op; + key.inplace = context.inplace; + key.overlap = context.overlap; + key.src_overlap = context.src_overlap; auto it = binary_pipelines.find(key); if (it != binary_pipelines.end()) { @@ -1908,9 +1868,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_concat_pipeline_key key = { - .type = context.dst->type, - }; + ggml_webgpu_concat_pipeline_key key = {}; + key.type = context.dst->type; auto it = concat_pipelines.find(key); if (it != concat_pipelines.end()) { @@ -1945,9 +1904,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_repeat_pipeline_key key = { - .type = context.dst->type, - }; + ggml_webgpu_repeat_pipeline_key key = {}; + key.type = context.dst->type; auto it = repeat_pipelines.find(key); if (it != repeat_pipelines.end()) { @@ -1985,16 +1943,16 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) { - auto it = flash_attn_pipelines.find(context.key); + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); + auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { return it->second; } - std::vector defines; std::string variant = "flash_attn"; - switch (context.key.kv_type) { + switch (key.kv_type) { case GGML_TYPE_F32: defines.push_back("KV_F32"); break; @@ -2010,111 +1968,206 @@ class ggml_webgpu_shader_lib { default: GGML_ABORT("Unsupported KV type for flash attention shader"); } - variant += std::string("_") + ggml_type_name(context.key.kv_type); + variant += std::string("_") + ggml_type_name(key.kv_type); - if (context.key.has_mask) { + if (key.has_mask) { defines.push_back("MASK"); variant += "_mask"; } - if (context.key.has_sinks) { + if (key.has_sinks) { defines.push_back("SINKS"); variant += "_sinks"; } - if (context.key.uses_logit_softcap) { + if (key.uses_logit_softcap) { defines.push_back("LOGIT_SOFTCAP"); variant += "_lgsc"; } - if (context.key.kv_direct) { + if (key.kv_direct) { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } - if (context.key.has_mask && context.key.use_vec) { - defines.push_back("BLK"); - variant += "_blk"; - } - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (context.key.use_vec) { - q_tile = 1; - kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context))); - kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; - const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key); - defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); - } - if (context.key.kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + auto decisions = std::make_shared(); + decisions->q_tile = context.sg_mat_m; + + const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + + if (key.kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { kv_tile -= context.sg_mat_n; } } - defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + decisions->kv_tile = kv_tile; + decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - uint32_t wg_size = 0; - if (context.key.use_vec) { - wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); - } else { - wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant); + pipeline.context = decisions; + flash_attn_pipelines[key] = pipeline; + return flash_attn_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); + auto it = flash_attn_vec_pipelines.find(key); + if (it != flash_attn_vec_pipelines.end()) { + return it->second; } - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; + std::vector defines; + std::string variant = "flash_attn_vec"; + + switch (key.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(key.kv_type); + + if (key.has_mask) { + defines.push_back("MASK"); + defines.push_back("BLK"); + variant += "_mask_blk"; + } + if (key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + defines.push_back("Q_TILE=1"); + + auto decisions = std::make_shared(); + decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context); + decisions->wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + uint32_t vec_ne = 1u; + + // Keep conservative defaults unless this is the f16 vec-split shape family. + if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) { + switch (key.head_dim_qk) { + case 64: + case 192: + case 576: + vec_ne = 2u; + break; + case 96: + vec_ne = 4u; + break; + default: + break; + } + } + + defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); + webgpu_pipeline pipeline = - ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - pipeline.context = decisions; - flash_attn_pipelines[context.key] = pipeline; - return flash_attn_pipelines[context.key]; - } - - webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { - auto it = flash_attn_blk_pipelines.find(context.key); + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant); + pipeline.context = decisions; + flash_attn_vec_pipelines[key] = pipeline; + return flash_attn_vec_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_blk_pipeline_key key = {}; + key.kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context); + auto it = flash_attn_blk_pipelines.find(key); if (it != flash_attn_blk_pipelines.end()) { return it->second; } - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); - flash_attn_blk_pipelines[context.key] = pipeline; - return flash_attn_blk_pipelines[context.key]; + std::vector defines; + std::string variant = "flash_attn_vec_blk"; + + defines.push_back(std::string("KV_TILE=") + std::to_string(key.kv_tile)); + variant += std::string("_kvt") + std::to_string(key.kv_tile); + + uint32_t wg_size = 1; + while ((wg_size << 1) <= context.max_wg_size) { + wg_size <<= 1; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + variant += std::string("_wg") + std::to_string(wg_size); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant); + flash_attn_blk_pipelines[key] = pipeline; + return flash_attn_blk_pipelines[key]; } - webgpu_pipeline get_flash_attn_vec_reduce_pipeline( - const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { - auto it = flash_attn_vec_reduce_pipelines.find(context.key); + webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {}; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.wg_size = context.max_wg_size; + auto it = flash_attn_vec_reduce_pipelines.find(key); if (it != flash_attn_vec_reduce_pipelines.end()) { return it->second; } - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); - flash_attn_vec_reduce_pipelines[context.key] = pipeline; - return flash_attn_vec_reduce_pipelines[context.key]; + std::vector defines; + std::string variant = "flash_attn_vec_reduce"; + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + variant += std::string("_wg") + std::to_string(context.max_wg_size); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant); + flash_attn_vec_reduce_pipelines[key] = pipeline; + return flash_attn_vec_reduce_pipelines[key]; } webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_cpy_pipeline_key key = { - .src_type = context.src0->type, - .dst_type = context.dst->type, - }; + ggml_webgpu_cpy_pipeline_key key = {}; + key.src_type = context.src0->type; + key.dst_type = context.dst->type; auto it = cpy_pipelines.find(key); if (it != cpy_pipelines.end()) { @@ -2166,11 +2219,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_glu_pipeline_key key = { - .glu_op = ggml_get_glu_op(context.dst), - .type = context.dst->type, - .split = (context.src1 != nullptr), - }; + ggml_webgpu_glu_pipeline_key key = {}; + key.glu_op = ggml_get_glu_op(context.dst); + key.type = context.dst->type; + key.split = (context.src1 != nullptr); auto it = glu_pipelines.find(key); if (it != glu_pipelines.end()) { @@ -2239,11 +2291,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_rope_pipeline_key key = { - .type = context.dst->type, - .inplace = context.inplace, - .has_ff = (context.src2 != nullptr), - }; + ggml_webgpu_rope_pipeline_key key = {}; + key.type = context.dst->type; + key.inplace = context.inplace; + key.has_ff = (context.src2 != nullptr); auto it = rope_pipelines.find(key); if (it != rope_pipelines.end()) { @@ -2288,12 +2339,11 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_soft_max_pipeline_key key = { - .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32, - .has_mask = (context.src1 != nullptr), - .has_sink = (context.src2 != nullptr), - .inplace = context.inplace, - }; + ggml_webgpu_soft_max_pipeline_key key = {}; + key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32; + key.has_mask = (context.src1 != nullptr); + key.has_sink = (context.src2 != nullptr); + key.inplace = context.inplace; auto it = soft_max_pipelines.find(key); if (it != soft_max_pipelines.end()) { @@ -2359,25 +2409,6 @@ class ggml_webgpu_shader_lib { pipeline_desc.layout = nullptr; // nullptr means auto layout return { device.CreateComputePipeline(&pipeline_desc), label }; } - - static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; - const size_t base_q_bytes = - (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!context.key.kv_direct) { - bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); - } - if (context.key.has_mask) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; - } }; #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 01637e2ddab..e7bda817a28 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -41,6 +41,12 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim wg_x = CEIL_DIV(total_wg, wg_y); } +static inline uint32_t ggml_webgpu_u32_from_f32(float value) { + uint32_t bits; + memcpy(&bits, &value, sizeof(bits)); + return bits; +} + #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl # define WEBGPU_DEBUG_BUF_ELEMS 512 @@ -369,6 +375,96 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, buffer = device.CreateBuffer(&buffer_desc); } +static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { + return webgpu_tensor_offset(tensor) + tensor->view_offs; +} + +static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { + ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + return ctx->buffer; +} + +static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +} + +static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx, + const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V) { + const size_t alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const uint32_t k_offset_elems = + (uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); + const uint32_t v_offset_elems = + (uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); + const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + + return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); +} + +static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +} + +static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { + return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); +} + +// Used to determine if two tensors are the same for in-place operations +static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); +} + +// Used to determine if two tensors share the same buffer and their byte ranges overlap, +static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && + ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); +} + +struct binary_overlap_flags { + bool inplace; // src0 == dst + bool overlap; // src1 == dst + bool src_overlap; +}; + +static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + binary_overlap_flags flags = {}; + flags.inplace = ggml_webgpu_tensor_equal(src0, dst); + flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); + flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); + + return flags; +} + +static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding, + wgpu::Buffer buffer, + uint64_t offset, + uint64_t size) { + wgpu::BindGroupEntry entry = {}; + entry.binding = binding; + entry.buffer = std::move(buffer); + entry.offset = offset; + entry.size = size; + return entry; +} + +static wgpu::BindGroupEntry ggml_webgpu_make_tensor_bind_group_entry(webgpu_context & ctx, + uint32_t binding, + ggml_tensor * tensor) { + return ggml_webgpu_make_bind_group_entry(binding, ggml_webgpu_tensor_buf(tensor), + ggml_webgpu_tensor_align_offset(ctx, tensor), + ggml_webgpu_tensor_binding_size(ctx, tensor)); +} + /** End WebGPU object initializations */ /** WebGPU Actions */ @@ -480,10 +576,8 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & std::vector entries = dispatch.bind_group_entries; uint32_t params_binding_num = entries.size(); - entries.push_back({ .binding = params_binding_num, - .buffer = ctx->param_arena.buffer, - .offset = param_offset, - .size = ctx->param_arena.slot_size }); + entries.push_back(ggml_webgpu_make_bind_group_entry(params_binding_num, ctx->param_arena.buffer, param_offset, + ctx->param_arena.slot_size)); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = dispatch.pipeline.pipeline.GetBindGroupLayout(0); @@ -502,13 +596,17 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & #ifdef GGML_WEBGPU_GPU_PROFILE for (size_t i = 0; i < dispatches.size(); i++) { GGML_ASSERT(ctx->profile_timestamp_query_count + 2 <= WEBGPU_MAX_PROFILE_QUERY_COUNT); - const uint32_t query_begin = ctx->profile_timestamp_query_count++; - const uint32_t query_end = ctx->profile_timestamp_query_count++; - wgpu::PassTimestampWrites ts_writes = { .querySet = ctx->profile_timestamp_query_set, - .beginningOfPassWriteIndex = query_begin, - .endOfPassWriteIndex = query_end }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); + const uint32_t query_begin = ctx->profile_timestamp_query_count++; + const uint32_t query_end = ctx->profile_timestamp_query_count++; + + wgpu::PassTimestampWrites ts_writes = {}; + ts_writes.querySet = ctx->profile_timestamp_query_set; + ts_writes.beginningOfPassWriteIndex = query_begin; + ts_writes.endOfPassWriteIndex = query_end; + wgpu::ComputePassDescriptor pass_desc = {}; + pass_desc.timestampWrites = &ts_writes; + + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); pass.SetPipeline(dispatches[i].pipeline.pipeline); pass.SetBindGroup(0, bind_groups[i]); @@ -544,17 +642,19 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, uint32_t value, size_t offset, size_t size) { - std::vector params = { (uint32_t) offset, (uint32_t) size, value }; - std::vector entries = { - { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } - }; - size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; - uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); + std::vector params = { (uint32_t) offset, (uint32_t) size, value }; + std::vector entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; + size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; + uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); - entries.push_back( - { .binding = 1, .buffer = ctx->memset_params_buf, .offset = 0, .size = WEBGPU_PARAMS_BUF_SIZE_BYTES }); + wgpu::BindGroupEntry params_entry = {}; + params_entry.binding = 1; + params_entry.buffer = ctx->memset_params_buf; + params_entry.offset = 0; + params_entry.size = WEBGPU_PARAMS_BUF_SIZE_BYTES; + entries.push_back(params_entry); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = ctx->memset_pipeline.pipeline.GetBindGroupLayout(0); @@ -632,65 +732,11 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { delete backend; } -static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { - return webgpu_tensor_offset(tensor) + tensor->view_offs; -} - -static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { - ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; - return ctx->buffer; -} - -static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { - size_t offset = ggml_webgpu_tensor_offset(t); - return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); -} - -static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { - size_t offset = ggml_webgpu_tensor_offset(t); - return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); -} - -static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { - return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); -} - -// Used to determine if two tensors are the same for in-place operations -static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); -} - -// Used to determine if two tensors share the same buffer and their byte ranges overlap, -static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && - ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); -} - -struct binary_overlap_flags { - bool inplace; // src0 == dst - bool overlap; // src1 == dst - bool src_overlap; -}; - -static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { - binary_overlap_flags flags = {}; - flags.inplace = ggml_webgpu_tensor_equal(src0, dst); - flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); - flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); - - return flags; -} - static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx); @@ -712,14 +758,8 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -732,13 +772,12 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * dst) { const bool inplace = ggml_webgpu_tensor_equal(src0, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx); @@ -772,29 +811,21 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, std::vector entries; uint32_t binding_index = 0; if (!inplace) { - entries.push_back({ .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); binding_index++; } - entries.push_back({ .binding = binding_index, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); - entries.push_back({ .binding = binding_index + 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index + 1, dst)); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx); @@ -832,14 +863,8 @@ static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -850,13 +875,12 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx); @@ -888,18 +912,9 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); @@ -911,12 +926,11 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -944,18 +958,9 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); @@ -971,15 +976,14 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, ggml_tensor * src4, ggml_tensor * src5, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .src3 = src3, - .src4 = src4, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.src3 = src3; + shader_lib_ctx.src4 = src4; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx); @@ -1015,34 +1019,10 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(src3), - .offset = ggml_webgpu_tensor_align_offset(ctx, src3), - .size = ggml_webgpu_tensor_binding_size(ctx, src3) }, - { .binding = 4, - .buffer = ggml_webgpu_tensor_buf(src4), - .offset = ggml_webgpu_tensor_align_offset(ctx, src4), - .size = ggml_webgpu_tensor_binding_size(ctx, src4) }, - { .binding = 5, - .buffer = ggml_webgpu_tensor_buf(src5), - .offset = ggml_webgpu_tensor_align_offset(ctx, src5), - .size = ggml_webgpu_tensor_binding_size(ctx, src5) }, - { .binding = 6, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, dst), }; return ggml_backend_webgpu_build(ctx, pipeline, params, entries, h, n_seqs); @@ -1058,12 +1038,11 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ct return std::nullopt; } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = idx, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = idx; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx); @@ -1086,25 +1065,14 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ct }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(idx), - .offset = ggml_webgpu_tensor_align_offset(ctx, idx), - .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; if (decisions->i64_idx) { - entries.push_back({ .binding = 3, - .buffer = ctx->set_rows_dev_error_buf, - .offset = 0, - .size = ctx->set_rows_dev_error_buf.GetSize() }); + entries.push_back(ggml_webgpu_make_bind_group_entry(3, ctx->set_rows_dev_error_buf, 0, + ctx->set_rows_dev_error_buf.GetSize())); } uint32_t threads; @@ -1131,12 +1099,11 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * dst) { const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = WEBGPU_MAX_WG_SIZE, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = WEBGPU_MAX_WG_SIZE; webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -1160,20 +1127,9 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(idx), - .offset = ggml_webgpu_tensor_align_offset(ctx, idx), - .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst) }; uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1)); uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]); @@ -1225,17 +1181,16 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, break; } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; // Get or create pipeline webgpu_pipeline pipeline; @@ -1270,18 +1225,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, // Build bind group entries std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; // Calculate workgroup dimensions @@ -1333,13 +1279,12 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; // Get or create pipeline webgpu_pipeline gather_pipeline, main_pipeline; @@ -1380,22 +1325,14 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // bind group entries for mul_mat_id_gather.wgsl std::vector gather_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_expert_used_align_offset, - .size = gathered_binding_size }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_tokens_align_offset, - .size = gathered_binding_size }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_count_ids_align_offset, - .size = gathered_count_ids_binding_size }, + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2), + ggml_webgpu_tensor_binding_size(ctx, src2)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), }; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; @@ -1427,30 +1364,18 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // bind group entries for mul_mat_id.wgsl std::vector main_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_expert_used_align_offset, - .size = gathered_binding_size }, - { .binding = 4, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_tokens_align_offset, - .size = gathered_binding_size }, - { .binding = 5, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_count_ids_align_offset, - .size = gathered_count_ids_binding_size }, + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(4, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(5, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), }; // Calculate workgroup dimensions @@ -1486,11 +1411,9 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * mask, ggml_tensor * sinks, ggml_tensor * dst) { - float scale = *(float *) dst->op_params; - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float logit_softcap; - memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + float scale = ggml_get_op_params_f32(dst, 0); + float max_bias = ggml_get_op_params_f32(dst, 1); + float logit_softcap = ggml_get_op_params_f32(dst, 2); if (logit_softcap != 0.0f) { scale /= logit_softcap; } @@ -1522,86 +1445,53 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) - *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) - *(uint32_t *) &max_bias, - *(uint32_t *) &logit_softcap, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap) + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(logit_softcap), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V), }; uint32_t binding_index = 3; if (has_mask) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } if (has_sinks) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); - } - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); - const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); - - const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned && - (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); - const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); - const bool use_blk = use_vec && has_mask; - - ggml_webgpu_flash_attn_pipeline_key key = { - .kv_type = K->type, - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .kv_direct = kv_direct, - .has_mask = static_cast(has_mask), - .has_sinks = static_cast(has_sinks), - .uses_logit_softcap = logit_softcap != 0.0f, - .use_vec = use_vec, - }; - - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { - .key = key, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, - }; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); - - auto * decisions = static_cast(pipeline.context.get()); - - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); - uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); + } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = Q; + shader_lib_ctx.src1 = K; + shader_lib_ctx.src2 = V; + shader_lib_ctx.src3 = mask; + shader_lib_ctx.src4 = sinks; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V); + webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) : + ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); + + if (!use_vec) { + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); + uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + } + + auto * decisions = static_cast(pipeline.context.get()); wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; @@ -1609,197 +1499,162 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_nblk1 = 0; uint32_t blk_batch_count = 0; - if (use_vec) { - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); - GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; - const bool use_vec_reduce = nwg > 1u; - GGML_ASSERT(nrows <= UINT32_MAX); - - uint64_t tmp_stats_base = 0; - uint64_t tmp_size_bytes = 0; - wgpu::Buffer tmp_buf = {}; - uint64_t tmp_bind_offset = 0; - uint64_t tmp_bind_size = 0; - const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const size_t dst_offset = ggml_webgpu_tensor_offset(dst); - size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); - - if (use_vec_reduce) { - const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; - const uint64_t tmp_stats_elems = nrows * 2u * nwg; - tmp_stats_base = tmp_data_elems; - tmp_size_bytes = - ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); - GGML_ASSERT(tmp_stats_base <= UINT32_MAX); - tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = scratch_offset; - tmp_bind_size = tmp_size_bytes; - scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); - } else { - // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. - tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); - tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); - } - - webgpu_pipeline blk_pipeline; - std::vector blk_params; - std::vector blk_entries; - if (use_blk) { - GGML_ASSERT(has_mask); - - blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); - blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); - blk_buf = ggml_webgpu_tensor_buf(dst); - const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); - blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { - .key = - { - .q_tile = decisions->q_tile, - .kv_tile = decisions->kv_tile, - }, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; - blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); - - blk_params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask - (uint32_t) Q->ne[1], // seq_len_q - (uint32_t) K->ne[1], // seq_len_kv - stride_mask3, // stride_mask3 - blk_nblk0, // nblk0 - blk_nblk1, // nblk1 - }; - blk_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, - { .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes }, - }; - scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); - } + const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + const bool use_vec_reduce = nwg > 1u; + GGML_ASSERT(nrows <= UINT32_MAX); + + uint64_t tmp_stats_base = 0; + uint64_t tmp_size_bytes = 0; + wgpu::Buffer tmp_buf = {}; + uint64_t tmp_bind_offset = 0; + uint64_t tmp_bind_size = 0; + const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); + + if (use_vec_reduce) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + tmp_stats_base = tmp_data_elems; + tmp_size_bytes = + ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + GGML_ASSERT(tmp_stats_base <= UINT32_MAX); + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); + } else { + // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); + tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); + } - std::vector split_params = params; - if (use_blk) { - split_params.push_back(0u); // blk_base - split_params.push_back(blk_nblk0); // blk_nblk0 - split_params.push_back(blk_nblk1); // blk_nblk1 - } - split_params.push_back(0u); // tmp_data_base - split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base - split_params.push_back(nwg); // nwg - - std::vector split_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) }, + webgpu_pipeline blk_pipeline; + std::vector blk_params; + std::vector blk_entries; + if (has_mask) { + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = (uint32_t) Q->ne[1]; + blk_buf = ggml_webgpu_tensor_buf(dst); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); + + blk_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) K->ne[1], // seq_len_kv + stride_mask3, // stride_mask3 + blk_nblk0, // nblk0 + blk_nblk1, // nblk1 }; - uint32_t split_binding_index = 3; - if (has_mask) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); - } - if (has_sinks) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); - } - if (use_blk) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = blk_buf, - .offset = blk_entries[1].offset, - .size = blk_size_bytes }); - } + blk_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask)), + ggml_webgpu_make_bind_group_entry(1, blk_buf, scratch_offset, blk_size_bytes), + }; + scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); + } + + std::vector split_params = params; + if (has_mask) { + split_params.push_back(0u); // blk_base + split_params.push_back(blk_nblk0); // blk_nblk0 + split_params.push_back(blk_nblk1); // blk_nblk1 + } + split_params.push_back(0u); // tmp_data_base + split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base + split_params.push_back(nwg); // nwg + + std::vector split_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), + ggml_webgpu_tensor_binding_size(ctx, Q)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K), + ggml_webgpu_tensor_binding_size(ctx, K)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), ggml_webgpu_tensor_align_offset(ctx, V), + ggml_webgpu_tensor_binding_size(ctx, V)), + }; + uint32_t split_binding_index = 3; + if (has_mask) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask))); + } + if (has_sinks) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks), + ggml_webgpu_tensor_align_offset(ctx, sinks), + ggml_webgpu_tensor_binding_size(ctx, sinks))); + } + if (has_mask) { split_entries.push_back( - { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size }); - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - webgpu_pipeline reduce_pipeline; - std::vector reduce_params; - std::vector reduce_entries; - if (use_vec_reduce) { - const uint32_t reduce_wg_size = std::max( - 32u, - std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); - ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { - .key = - { - .head_dim_v = (uint32_t) V->ne[0], - .wg_size = reduce_wg_size, - }, - .max_wg_size = reduce_wg_size, - }; - reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); - - reduce_params = { - (uint32_t) nrows, // nrows - (uint32_t) Q->ne[1], // seq_len_q - (uint32_t) Q->ne[2], // n_heads - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst - nwg, // nwg - 0u, // tmp_data_base - (uint32_t) tmp_stats_base, // tmp_stats_base - }; - - reduce_entries = { - { .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - }; - } + ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes)); + } + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(split_binding_index++, tmp_buf, tmp_bind_offset, tmp_bind_size)); + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(dst), + ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst))); + + webgpu_pipeline reduce_pipeline; + std::vector reduce_params; + std::vector reduce_entries; + if (use_vec_reduce) { + const uint32_t reduce_wg_size = std::max( + 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; + reduce_shader_ctx.max_wg_size = reduce_wg_size; + reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); + + reduce_params = { + (uint32_t) nrows, // nrows + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) Q->ne[2], // n_heads + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst + nwg, // nwg + 0u, // tmp_data_base + (uint32_t) tmp_stats_base, // tmp_stats_base + }; - const uint64_t split_wg_total = (uint64_t) wg_x * nwg; - GGML_ASSERT(split_wg_total <= UINT32_MAX); - std::vector dispatches; - - if (use_blk) { - dispatches.push_back({ - blk_pipeline, - std::move(blk_params), - std::move(blk_entries), - { blk_nblk0, blk_nblk1 * blk_batch_count } - }); - } + reduce_entries = { + ggml_webgpu_make_bind_group_entry(0, tmp_buf, tmp_bind_offset, tmp_size_bytes), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + }; + } + + uint32_t wg_x = Q->ne[1] * Q->ne[2] * Q->ne[3]; + const uint64_t split_wg_total = (uint64_t) wg_x * nwg; + GGML_ASSERT(split_wg_total <= UINT32_MAX); + + std::vector dispatches; + + if (has_mask) { dispatches.push_back({ - pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count } + }); + } + dispatches.push_back({ + pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + }); + if (use_vec_reduce) { + dispatches.push_back({ + reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } }); - if (use_vec_reduce) { - dispatches.push_back({ - reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } - }); - } - - return ggml_backend_webgpu_build_multi(ctx, dispatches); } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } #endif // __EMSCRIPTEN__ @@ -1807,13 +1662,12 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx); @@ -1844,10 +1698,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor float alpha_p = ggml_get_op_params_f32(dst, 2); float beta = ggml_get_op_params_f32(dst, 3); float eps = ggml_get_op_params_f32(dst, 4); - params.push_back(*reinterpret_cast(&alpha_n)); - params.push_back(*reinterpret_cast(&alpha_p)); - params.push_back(*reinterpret_cast(&beta)); - params.push_back(*reinterpret_cast(&eps)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_n)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_p)); + params.push_back(ggml_webgpu_u32_from_f32(beta)); + params.push_back(ggml_webgpu_u32_from_f32(eps)); break; } default: @@ -1856,25 +1710,19 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor } else if (dst->op == GGML_OP_CLAMP) { float clamp_min = ggml_get_op_params_f32(dst, 0); float clamp_max = ggml_get_op_params_f32(dst, 1); - params.push_back(*reinterpret_cast(&clamp_min)); - params.push_back(*reinterpret_cast(&clamp_max)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_min)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_max)); } else if (dst->op == GGML_OP_FILL) { float fill_val = ggml_get_op_params_f32(dst, 0); - params.push_back(*reinterpret_cast(&fill_val)); + params.push_back(ggml_webgpu_u32_from_f32(fill_val)); effective_src = dst; // fill simply fills dst } std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(effective_src), - .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src), - .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, effective_src), }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -1887,15 +1735,14 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = flags.inplace, - .overlap = flags.overlap, - .src_overlap = flags.src_overlap, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = flags.inplace; + shader_lib_ctx.overlap = flags.overlap; + shader_lib_ctx.src_overlap = flags.src_overlap; webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); @@ -1944,38 +1791,18 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0), src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1)); - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = merged_offset, - .size = merged_end - merged_offset, - }); - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst), - }); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, + merged_end - merged_offset)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } else { - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = src0_webgpu_tensor_align_offset, - .size = ggml_webgpu_tensor_binding_size(ctx, src0), - }); - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = src1_webgpu_tensor_align_offset, - .size = ggml_webgpu_tensor_binding_size(ctx, src1), - }); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), + src0_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src0))); + entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), + src1_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src1))); if (!flags.inplace && !flags.overlap) { - entries.push_back({ - .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst), - }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); } } @@ -2012,26 +1839,16 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2059,21 +1876,14 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * (uint32_t) (dst->ne[2]) }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2097,28 +1907,19 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3], - *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)) // epsilon, treated as f32 in the shader }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src)); @@ -2129,14 +1930,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = ggml_webgpu_tensor_equal(src0, dst), - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); @@ -2187,41 +1987,27 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, (uint32_t) src0->ne[2], (uint32_t) n_dims, (uint32_t) mode, - *(uint32_t *) &theta_scale, - *(uint32_t *) &attn_factor, - *(uint32_t *) &freq_scale, - *(uint32_t *) &ext_factor, - *(uint32_t *) &corr_dims[0], - *(uint32_t *) &corr_dims[1], + ggml_webgpu_u32_from_f32(theta_scale), + ggml_webgpu_u32_from_f32(attn_factor), + ggml_webgpu_u32_from_f32(freq_scale), + ggml_webgpu_u32_from_f32(ext_factor), + ggml_webgpu_u32_from_f32(corr_dims[0]), + ggml_webgpu_u32_from_f32(corr_dims[1]), (uint32_t) sections[0], (uint32_t) sections[1], (uint32_t) sections[2], (uint32_t) sections[3] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) } - }; - uint32_t dst_binding = 2; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1) }; + uint32_t dst_binding = 2; if (has_freq_factor) { dst_binding = 3; - entries.push_back({ .binding = 2, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); } if (!inplace) { - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); @@ -2232,12 +2018,11 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx); @@ -2265,29 +2050,20 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], - (uint32_t) ((int32_t *) dst->op_params)[1], // swapped - *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai - *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai + (uint32_t) ((int32_t *) dst->op_params)[1], // swapped + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 2)), // alpha, for swiglu_oai + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 3)), // limit, for swiglu_oai }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), }; uint32_t dst_binding = 1; if (split) { dst_binding = 2; - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); } - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); @@ -2296,13 +2072,12 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2321,23 +2096,15 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &dst->op_params[1] // bias + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 1)) // bias }; // bindgroups unchanged - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); @@ -2349,25 +2116,23 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = ggml_webgpu_tensor_equal(src0, dst), - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int has_mask = (src1 != nullptr); - const int has_sink = (src2 != nullptr); - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); - float m0 = powf(2.0f, -(max_bias) / n_head_log2); - float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); + float max_bias = ggml_get_op_params_f32(dst, 1); + float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -2389,39 +2154,29 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, (uint32_t) src0->ne[2], has_mask ? (uint32_t) src1->ne[2] : 0, has_mask ? (uint32_t) src1->ne[3] : 0, - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &max_bias, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) } - }; - uint32_t binding_num = 1; + std::vector entries = { ggml_webgpu_make_bind_group_entry( + 0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)) }; + uint32_t binding_num = 1; if (has_mask) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back(ggml_webgpu_make_bind_group_entry(binding_num, ggml_webgpu_tensor_buf(src1), + ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1))); binding_num++; } if (has_sink) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, src2)); binding_num++; } if (!inplace) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, dst)); } return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); @@ -2432,20 +2187,13 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nelements(dst); @@ -2455,13 +2203,12 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_top_k = dst->op == GGML_OP_TOP_K; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx); auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); @@ -2527,11 +2274,8 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * const uint32_t wg_x_init = std::min(total_wg_init, max_wg); const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); std::vector init_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size) }; dispatches.push_back({ @@ -2580,12 +2324,9 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * nrows }; std::vector merge_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in }, - { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), align_in, size_in), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out) }; const uint32_t total_wg_merge = nm * nrows; @@ -2607,23 +2348,14 @@ static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nrows(dst); @@ -2641,20 +2373,13 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor total_sum ? 1 : (uint32_t) src->ne[1], total_sum ? 1 : (uint32_t) src->ne[2] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); @@ -3133,40 +2858,24 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const ggml_tensor * mask = tensor->src[3]; const ggml_tensor * sinks = tensor->src[4]; if (Q && K && V) { - GGML_UNUSED(sinks); - const bool kv_direct = (K->type == GGML_TYPE_F16) && - (Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && - kv_vec_type_supported && (V->type == K->type); - if (use_vec) { - const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; - const size_t limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const size_t q_tile = sg_mat_m; - const size_t base_q_bytes = (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!kv_direct) { - bytes_per_kv += std::max(Q->ne[0], V->ne[0]); - } - if (mask != nullptr) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - uint32_t kv_tile = ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; - kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); - kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; - if (kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= sg_mat_n; - } - } + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = const_cast(Q); + shader_lib_ctx.src1 = const_cast(K); + shader_lib_ctx.src2 = const_cast(V); + shader_lib_ctx.src3 = const_cast(mask); + shader_lib_ctx.src4 = const_cast(sinks); + shader_lib_ctx.dst = const_cast(tensor); + shader_lib_ctx.max_wg_size = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; + + if (ggml_webgpu_flash_attn_use_vec(ctx->webgpu_global_ctx, Q, K, V)) { + const uint32_t kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(shader_lib_ctx); const uint32_t vec_nwg_cap = std::max( 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); @@ -3271,8 +2980,9 @@ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct } static ggml_guid_t ggml_backend_webgpu_guid(void) { - static const char * guid_str = "__ggml_webgpu :)"; - return reinterpret_cast((void *) guid_str); + static ggml_guid guid = { 0x67, 0xc7, 0xa4, 0xb1, 0x78, 0x74, 0x4f, 0x51, + 0x9d, 0x65, 0x44, 0x6d, 0xe4, 0x1b, 0x82, 0x9a }; + return &guid; } static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { @@ -3931,20 +3641,23 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); - static ggml_backend_webgpu_reg_context ctx; - static ggml_backend_reg reg = { + // Intentionally leak the global registry context to avoid crashing inside + // Dawn/Vulkan static teardown during process exit. + static ggml_backend_webgpu_reg_context * ctx = new ggml_backend_webgpu_reg_context(); + + static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION, /* .iface = */ ggml_backend_webgpu_reg_i, - /* .context = */ &ctx, + /* .context = */ ctx, }; - ctx.name = GGML_WEBGPU_NAME; - ctx.device_count = 0; + ctx->name = GGML_WEBGPU_NAME; + ctx->device_count = 0; // Keep one Dawn/WebGPU instance alive for the lifetime of the static backend // registry. Recreating it on repeated registry lookups can invalidate // adapter/device references that are still held by the backend/device layer. - if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) { + if (ctx->webgpu_global_ctx != nullptr && ctx->webgpu_global_ctx->instance != nullptr) { return ® } @@ -3961,17 +3674,18 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { instance_descriptor.nextInChain = &instanceTogglesDesc; #endif - wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); - ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); - ctx.webgpu_global_ctx->instance = std::move(inst); + wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); + ctx->webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); + ctx->webgpu_global_ctx->instance = std::move(inst); // Probe for adapter support wgpu::Adapter adapter; - if (ctx.webgpu_global_ctx->instance != nullptr) { + if (ctx->webgpu_global_ctx->instance != nullptr) { wgpu::RequestAdapterOptions options = {}; - ctx.webgpu_global_ctx->instance.WaitAny( - ctx.webgpu_global_ctx->instance.RequestAdapter( + // probe for adapter support + ctx->webgpu_global_ctx->instance.WaitAny( + ctx->webgpu_global_ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { if (status != wgpu::RequestAdapterStatus::Success) { @@ -3984,7 +3698,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { } if (adapter != nullptr) { - ctx.device_count = 1; + ctx->device_count = 1; } return ® diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl index 82d072be73a..61107c6a985 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -1,7 +1,6 @@ diagnostic(off, subgroup_uniformity); enable f16; -#define Q_TILE 1 #define KV_TILE 32 #define WG_SIZE 32 @@ -11,7 +10,7 @@ struct Params { seq_len_kv: u32, stride_mask3: u32, // Number of KV blocks and Q blocks per batch. - // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE). + // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = seq_len_q. nblk0: u32, nblk1: u32, }; @@ -40,7 +39,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, return; } - let q_start = q_blk * Q_TILE; + let q_start = q_blk; let k_start = kv_blk * KV_TILE; let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u); @@ -54,11 +53,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var local_max = -MASK_MAX; var local_any = 0u; - for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) { - let q_row = q_start + q_rel; - if (q_row >= params.seq_len_q) { - continue; - } + let q_row = q_start; + if (q_row < params.seq_len_q) { let row_base = mask_batch_base + q_row * params.seq_len_kv; for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) { let k_col = k_start + k_rel; From a899e4bdcbda94e099c7a2ac40ff26b490419cca Mon Sep 17 00:00:00 2001 From: SamareshSingh <97642706+ssam18@users.noreply.github.com> Date: Sat, 18 Apr 2026 03:04:51 -0500 Subject: [PATCH 154/249] ggml-backend-meta: add multi-segment read support in get_tensor (llama/22063) --- ggml/src/ggml-backend-meta.cpp | 40 +++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 1ee3eeb4d96..24f6bc0639d 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1270,7 +1270,45 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co GGML_ASSERT(ggml_is_contiguous(tensor)); const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - GGML_ASSERT(split_state.n_segments == 1); + + if (split_state.n_segments != 1) { + GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + GGML_ASSERT(tensor->ne[3] == 1); + size_t offset_data = 0; + std::vector simple_offsets(n_bufs, 0); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(tensor->ne[2] == 1); + const int64_t blck_size = ggml_blck_size(tensor->type); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes, + tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*tensor->ne[1] == size); + return; + } + GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes, + tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*tensor->ne[2] == size); + return; + } switch (split_state.axis) { case GGML_BACKEND_SPLIT_AXIS_0: From 32789b9e07afc115eec3be81a76a34453e90ae67 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Sun, 19 Apr 2026 10:21:53 +0300 Subject: [PATCH 155/249] rpc : refactor the RPC transport (llama/21998) * rpc : refactor the RPC transport Move all transport related code into a separate file and use the socket_t interface to hide all transport implementation details. * fix win32 * better socket_t construction --- ggml/src/ggml-rpc/CMakeLists.txt | 1 + ggml/src/ggml-rpc/ggml-rpc.cpp | 806 +++---------------------------- ggml/src/ggml-rpc/transport.cpp | 683 ++++++++++++++++++++++++++ ggml/src/ggml-rpc/transport.h | 34 ++ 4 files changed, 782 insertions(+), 742 deletions(-) create mode 100644 ggml/src/ggml-rpc/transport.cpp create mode 100644 ggml/src/ggml-rpc/transport.h diff --git a/ggml/src/ggml-rpc/CMakeLists.txt b/ggml/src/ggml-rpc/CMakeLists.txt index 8671ce5ceaf..40e11fead63 100644 --- a/ggml/src/ggml-rpc/CMakeLists.txt +++ b/ggml/src/ggml-rpc/CMakeLists.txt @@ -2,6 +2,7 @@ message(STATUS "Using RPC backend") ggml_add_backend_library(ggml-rpc ggml-rpc.cpp + transport.cpp ) if (WIN32) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 017ef0af360..2ded7397868 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -2,6 +2,7 @@ #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-cpp.h" +#include "transport.h" #include #include @@ -12,35 +13,11 @@ #include #include #include -#ifdef _WIN32 -# define WIN32_LEAN_AND_MEAN -# ifndef NOMINMAX -# define NOMINMAX -# endif -# include -# include -#else -# include -# include -# include -# include -# include -# include -# include -#endif #include #include #include #include -#ifdef GGML_RPC_RDMA -# include -# include -# ifndef _WIN32 -# include -# endif -#endif // GGML_RPC_RDMA - static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); #define LOG_DBG(...) \ @@ -49,128 +26,6 @@ static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); namespace fs = std::filesystem; -static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB - -#ifdef _WIN32 -typedef SOCKET sockfd_t; -using ssize_t = __int64; -#else -typedef int sockfd_t; -#endif - -// cross-platform socket - -#ifdef GGML_RPC_RDMA -static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) -static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB -static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes -using rdma_gid_t = std::array; - -struct rdma_conn { - struct ibv_context * ctx = nullptr; - struct ibv_pd * pd = nullptr; - struct ibv_cq * scq = nullptr; // send completions - struct ibv_cq * rcq = nullptr; // recv completions - struct ibv_qp * qp = nullptr; - - void * tx_buf = nullptr; - struct ibv_mr * tx_mr = nullptr; - - void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous - struct ibv_mr * rx_mr = nullptr; - int rx_head = 0; - - uint32_t max_inline = 0; - - uint8_t * rx_slot(int i) const { - return static_cast(rx_buf) + static_cast(i) * RDMA_CHUNK; - } - - bool post_rx(int i) { - struct ibv_sge sge = {}; - sge.addr = (uintptr_t)rx_slot(i); - sge.length = RDMA_CHUNK; - sge.lkey = rx_mr->lkey; - struct ibv_recv_wr wr = {}, * bad = nullptr; - wr.wr_id = (uint64_t)i; - wr.sg_list = &sge; - wr.num_sge = 1; - return ibv_post_recv(qp, &wr, &bad) == 0; - } - - ~rdma_conn() { - if (tx_mr) ibv_dereg_mr(tx_mr); - if (rx_mr) ibv_dereg_mr(rx_mr); - free(tx_buf); - free(rx_buf); - if (qp) ibv_destroy_qp(qp); - if (scq) ibv_destroy_cq(scq); - if (rcq) ibv_destroy_cq(rcq); - if (pd) ibv_dealloc_pd(pd); - if (ctx) ibv_close_device(ctx); - } -}; - -// Local RDMA parameters captured during the probe phase and later consumed -// by rdma_activate() after the remote side's caps arrive via HELLO. -struct rdma_local_info { - uint32_t qpn = 0; - uint32_t psn = 0; - uint8_t gid[RDMA_GID_SIZE] = {}; - uint8_t ib_port = 0; - int gid_idx = 0; - enum ibv_mtu path_mtu = IBV_MTU_1024; -}; -#endif // GGML_RPC_RDMA - -// conn_caps size for transport-agnostic capability exchange -static constexpr size_t RPC_CONN_CAPS_SIZE = 24; - -// conn_caps RDMA layout helper -#ifdef GGML_RPC_RDMA -struct rdma_caps { - uint32_t qpn; - uint32_t psn; - uint8_t gid[RDMA_GID_SIZE]; -}; -static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); -#endif // GGML_RPC_RDMA - -// Forward declarations for transport function pointers -struct socket_t; -static bool tcp_send_impl(socket_t * sock, const void * data, size_t size); -static bool tcp_recv_impl(socket_t * sock, void * data, size_t size); - -struct socket_t { - sockfd_t fd; - bool (*fn_send)(socket_t *, const void *, size_t) = tcp_send_impl; - bool (*fn_recv)(socket_t *, void *, size_t) = tcp_recv_impl; -#ifdef GGML_RPC_RDMA - std::unique_ptr rdma; - rdma_local_info rdma_local = {}; -#endif // GGML_RPC_RDMA - socket_t(sockfd_t fd) : fd(fd) {} - ~socket_t() { -#ifdef GGML_RPC_RDMA - rdma.reset(); -#endif // GGML_RPC_RDMA - LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); -#ifdef _WIN32 - if (fd != INVALID_SOCKET) closesocket(this->fd); -#else - if (fd >= 0) close(this->fd); -#endif - } - - // Advertise local transport capabilities into conn_caps. - // May probe RDMA and store the probe on this socket for update_caps. - void get_caps(uint8_t * caps); - - // Activate transport upgrade based on remote conn_caps using the probe - // previously stored by get_caps. - void update_caps(const uint8_t * remote_caps); -}; - // macro for nicer error messages on server crash #define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response") @@ -403,540 +258,27 @@ static uint64_t fnv_hash(const uint8_t * data, size_t len) { return hash; } -static std::shared_ptr make_socket(sockfd_t fd) { -#ifdef _WIN32 - if (fd == INVALID_SOCKET) { - return nullptr; - } -#else - if (fd < 0) { - return nullptr; - } -#endif - return std::make_shared(fd); -} - -static bool set_no_delay(sockfd_t sockfd) { - int flag = 1; - // set TCP_NODELAY to disable Nagle's algorithm - int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); - return ret == 0; -} - -static bool set_reuse_addr(sockfd_t sockfd) { - int flag = 1; - int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); - return ret == 0; -} - -static std::shared_ptr socket_connect(const char * host, int port) { - struct sockaddr_in addr; - auto sockfd = socket(AF_INET, SOCK_STREAM, 0); - auto sock_ptr = make_socket(sockfd); - if (sock_ptr == nullptr) { - return nullptr; - } - if (!set_no_delay(sockfd)) { - GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); - return nullptr; - } - addr.sin_family = AF_INET; - addr.sin_port = htons(port); - struct hostent * server = gethostbyname(host); - if (server == NULL) { - GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); - return nullptr; - } - memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); - if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { - return nullptr; - } - return sock_ptr; -} - -static std::shared_ptr socket_accept(sockfd_t srv_sockfd) { - auto client_socket_fd = accept(srv_sockfd, NULL, NULL); - auto client_socket = make_socket(client_socket_fd); - if (client_socket == nullptr) { - return nullptr; - } - if (!set_no_delay(client_socket_fd)) { - GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); - return nullptr; - } - return client_socket; -} - -static std::shared_ptr create_server_socket(const char * host, int port) { - auto sockfd = socket(AF_INET, SOCK_STREAM, 0); - auto sock = make_socket(sockfd); - if (sock == nullptr) { - return nullptr; - } - if (!set_reuse_addr(sockfd)) { - GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); - return nullptr; - } - if (inet_addr(host) == INADDR_NONE) { - GGML_LOG_ERROR("Invalid host address: %s\n", host); - return nullptr; - } - struct sockaddr_in serv_addr; - serv_addr.sin_family = AF_INET; - serv_addr.sin_addr.s_addr = inet_addr(host); - serv_addr.sin_port = htons(port); - - if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { - return nullptr; - } - if (listen(sockfd, 1) < 0) { - return nullptr; - } - return sock; -} - -static bool send_data(sockfd_t sockfd, const void * data, size_t size) { - size_t bytes_sent = 0; - while (bytes_sent < size) { - size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); - ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0); - if (n < 0) { - GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", - bytes_sent, size_to_send); - return false; - } - bytes_sent += (size_t)n; - } - return true; -} - -static bool recv_data(sockfd_t sockfd, void * data, size_t size) { - size_t bytes_recv = 0; - while (bytes_recv < size) { - size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE); - ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0); - if (n < 0) { - GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", - bytes_recv, size_to_recv); - return false; - } - if (n == 0) { - LOG_DBG("recv returned 0 (peer closed?)\n"); - return false; - } - bytes_recv += (size_t)n; - } - return true; -} - -// TCP transport implementations (for function-pointer dispatch) - -static bool tcp_send_impl(socket_t * sock, const void * data, size_t size) { - return send_data(sock->fd, data, size); -} - -static bool tcp_recv_impl(socket_t * sock, void * data, size_t size) { - return recv_data(sock->fd, data, size); -} - -// RDMA transport (performance-optimized, auto-negotiated) - -#ifdef GGML_RPC_RDMA - -static bool rdma_send_impl(socket_t * sock, const void * data, size_t size); -static bool rdma_recv_impl(socket_t * sock, void * data, size_t size); - -static inline bool tcp_peer_closed(int fd) { - if (fd < 0) return false; -#ifndef _WIN32 - struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; - int r = poll(&pfd, 1, 0); - return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP)); -#else - return false; -#endif -} - -static inline bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc, int tcp_fd) { - for (uint64_t s = 0; ; s++) { - int n = ibv_poll_cq(cq, 1, wc); - if (n > 0) { - if (wc->status != IBV_WC_SUCCESS) { - GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", - wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); - } - return wc->status == IBV_WC_SUCCESS; - } - if (n < 0) return false; - if ((s & 0xFFFFF) == 0 && s > 0) { - if (tcp_peer_closed(tcp_fd)) { - return false; - } - } - } -} - -static bool rdma_send(rdma_conn * c, const void * data, size_t size, int tcp_fd) { - const uint8_t * src = (const uint8_t *)data; - size_t rem = size; - while (rem > 0) { - size_t chunk = std::min(rem, RDMA_CHUNK); - - struct ibv_sge sge = {}; - struct ibv_send_wr wr = {}, * bad = nullptr; - wr.opcode = IBV_WR_SEND; - wr.sg_list = &sge; - wr.num_sge = 1; - - if (chunk <= c->max_inline) { - sge.addr = (uintptr_t)src; - sge.length = chunk; - wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; - } else { - memcpy(c->tx_buf, src, chunk); - sge.addr = (uintptr_t)c->tx_buf; - sge.length = chunk; - sge.lkey = c->tx_mr->lkey; - wr.send_flags = IBV_SEND_SIGNALED; - } - - if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; - struct ibv_wc wc; - if (!rdma_poll(c->scq, &wc, tcp_fd)) return false; - - src += chunk; - rem -= chunk; - } - return true; -} - - -static bool rdma_recv(rdma_conn * c, void * data, size_t size, int tcp_fd) { - uint8_t * dst = (uint8_t *)data; - size_t rem = size; - while (rem > 0) { - struct ibv_wc wc; - if (!rdma_poll(c->rcq, &wc, tcp_fd)) return false; - - int slot = (int)wc.wr_id; - size_t got = wc.byte_len; - memcpy(dst, c->rx_slot(slot), got); - - if (!c->post_rx(slot)) return false; - - dst += got; - rem -= got; - } - return true; -} - -static bool rdma_send_impl(socket_t * sock, const void * data, size_t size) { - return rdma_send(sock->rdma.get(), data, size, sock->fd); -} - -static bool rdma_recv_impl(socket_t * sock, void * data, size_t size) { - return rdma_recv(sock->rdma.get(), data, size, sock->fd); -} - -// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. -// Used to match the socket's local IP against the kernel's GID table so that -// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: -// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) -// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) -// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is -// Returns std::nullopt on unsupported family or getsockname failure. -static std::optional rdma_build_target_gid(sockfd_t tcp_fd) { - sockaddr_storage addr = {}; - socklen_t addr_len = sizeof(addr); - if (getsockname(tcp_fd, reinterpret_cast(&addr), &addr_len) != 0) { - return std::nullopt; - } - rdma_gid_t target = {}; - if (addr.ss_family == AF_INET) { - const auto * a = reinterpret_cast(&addr); - target[10] = 0xff; - target[11] = 0xff; - memcpy(&target[12], &a->sin_addr, 4); - return target; - } - if (addr.ss_family == AF_INET6) { - const auto * a = reinterpret_cast(&addr); - memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); - return target; - } - return std::nullopt; -} - -static rdma_conn * rdma_probe(sockfd_t tcp_fd, rdma_local_info * out) { - const char * dev_env = std::getenv("GGML_RDMA_DEV"); - const char * gid_env = std::getenv("GGML_RDMA_GID"); - - auto target_gid = rdma_build_target_gid(tcp_fd); - if (!target_gid) { - return nullptr; - } - - const uint8_t ib_port = 1; - int num_devs = 0; - ibv_device ** devs = ibv_get_device_list(&num_devs); - if (!devs || num_devs == 0) return nullptr; - - ibv_context * ibctx = nullptr; - const char * matched_dev = nullptr; - int gid_idx = gid_env ? atoi(gid_env) : -1; - int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB - - for (int d = 0; d < num_devs; d++) { - const char * dn = ibv_get_device_name(devs[d]); - if (dev_env && strcmp(dev_env, dn) != 0) continue; - - ibv_context * ctx = ibv_open_device(devs[d]); - if (!ctx) continue; - - ibv_port_attr pa; - if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } - - int found_gid = gid_idx; - int found_version = IBV_GID_TYPE_IB; - if (found_gid < 0) { - // Find a GID on this port whose bytes equal the local TCP address - // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 - // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths - // are avoided. ibv_query_gid_ex returns gid+type in one call. - int v2_idx = -1; - int v1_idx = -1; - for (int i = 0; i < pa.gid_tbl_len; i++) { - ibv_gid_entry entry = {}; - if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; - if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; - if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { - v2_idx = i; - } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { - v1_idx = i; - } - } - if (v2_idx >= 0) { - found_gid = v2_idx; - found_version = IBV_GID_TYPE_ROCE_V2; - } else if (v1_idx >= 0) { - found_gid = v1_idx; - found_version = IBV_GID_TYPE_ROCE_V1; - } - } else { - // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. - ibv_gid_entry entry = {}; - if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { - found_version = entry.gid_type; - } - } - if (found_gid >= 0) { - ibctx = ctx; - gid_idx = found_gid; - gid_version = found_version; - matched_dev = dn; - out->path_mtu = pa.active_mtu; - break; - } - ibv_close_device(ctx); - } - ibv_free_device_list(devs); - if (!ibctx) return nullptr; - - out->ib_port = ib_port; - out->gid_idx = gid_idx; - - // unique_ptr owns ibctx and every subsequent resource via ~rdma_conn(), - // so each failure path is a plain `return nullptr;`. - auto c = std::make_unique(); - c->ctx = ibctx; - - c->pd = ibv_alloc_pd(ibctx); - if (!c->pd) return nullptr; - - c->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); - c->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); - if (!c->scq || !c->rcq) return nullptr; - - ibv_qp_init_attr qia = {}; - qia.send_cq = c->scq; - qia.recv_cq = c->rcq; - qia.qp_type = IBV_QPT_RC; - qia.cap.max_send_wr = 4; - qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; - qia.cap.max_send_sge = 1; - qia.cap.max_recv_sge = 1; - qia.cap.max_inline_data = 256; - - c->qp = ibv_create_qp(c->pd, &qia); - if (!c->qp) return nullptr; - c->max_inline = qia.cap.max_inline_data; - - c->tx_buf = aligned_alloc(4096, RDMA_CHUNK); - c->rx_buf = aligned_alloc(4096, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK); - if (!c->tx_buf || !c->rx_buf) return nullptr; - - c->tx_mr = ibv_reg_mr(c->pd, c->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); - c->rx_mr = ibv_reg_mr(c->pd, c->rx_buf, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); - if (!c->tx_mr || !c->rx_mr) return nullptr; - - ibv_gid local_gid; - if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return nullptr; - - out->qpn = c->qp->qp_num; - out->psn = c->qp->qp_num & 0xffffff; - memcpy(out->gid, &local_gid, RDMA_GID_SIZE); - - const char * ver_str = ""; - if (gid_version == IBV_GID_TYPE_ROCE_V2) { - ver_str = " RoCEv2"; - } else if (gid_version == IBV_GID_TYPE_ROCE_V1) { - ver_str = " RoCEv1"; - } - GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", - matched_dev, gid_idx, ver_str, out->qpn, c->max_inline); - return c.release(); -} - -// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. -// On success, the connection is live and ready for rdma_send/rdma_recv. -static bool rdma_activate(rdma_conn * c, const rdma_local_info * local, - uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { - // RESET -> INIT - { - struct ibv_qp_attr a = {}; - a.qp_state = IBV_QPS_INIT; - a.port_num = local->ib_port; - a.pkey_index = 0; - a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; - if (ibv_modify_qp(c->qp, &a, - IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { - return false; - } - } - - for (int i = 0; i < RDMA_RX_DEPTH; i++) { - if (!c->post_rx(i)) return false; - } - - // INIT -> RTR - { - struct ibv_qp_attr a = {}; - a.qp_state = IBV_QPS_RTR; - a.path_mtu = local->path_mtu; - a.dest_qp_num = remote_qpn; - a.rq_psn = remote_psn; - a.max_dest_rd_atomic = 1; - a.min_rnr_timer = 1; - a.ah_attr.is_global = 1; - memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); - a.ah_attr.grh.hop_limit = 1; - a.ah_attr.grh.sgid_index = local->gid_idx; - a.ah_attr.dlid = 0; - a.ah_attr.port_num = local->ib_port; - if (ibv_modify_qp(c->qp, &a, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | - IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { - return false; - } - } - - // RTR -> RTS - { - struct ibv_qp_attr a = {}; - a.qp_state = IBV_QPS_RTS; - a.timeout = 14; - a.retry_cnt = 7; - a.rnr_retry = 7; - a.sq_psn = local->psn; - a.max_rd_atomic = 1; - if (ibv_modify_qp(c->qp, &a, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | - IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { - return false; - } - } - - GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n", - local->qpn, remote_qpn, 128 << local->path_mtu, RDMA_RX_DEPTH); - return true; -} - -#endif // GGML_RPC_RDMA - -// --------------------------------------------------------------------------- -// socket_t transport capability methods -// --------------------------------------------------------------------------- - -void socket_t::get_caps(uint8_t * caps) { - memset(caps, 0, RPC_CONN_CAPS_SIZE); -#ifdef GGML_RPC_RDMA - rdma_local = {}; - rdma.reset(rdma_probe(fd, &rdma_local)); - if (rdma) { - rdma_caps rc = {}; - rc.qpn = rdma_local.qpn; - rc.psn = rdma_local.psn; - memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); - memcpy(caps, &rc, sizeof(rc)); - } -#endif // GGML_RPC_RDMA -} - -void socket_t::update_caps(const uint8_t * remote_caps) { -#ifdef GGML_RPC_RDMA - if (!rdma) { - return; - } - rdma_caps rc = {}; - memcpy(&rc, remote_caps, sizeof(rc)); - if (rc.qpn == 0) { - rdma.reset(); - return; - } - if (rdma_activate(rdma.get(), &rdma_local, rc.qpn, rc.psn, rc.gid)) { - fn_send = rdma_send_impl; - fn_recv = rdma_recv_impl; - } else { - GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); - rdma.reset(); - } -#else - (void)remote_caps; -#endif // GGML_RPC_RDMA -} - -// unified transport dispatch (via function pointers) - -static bool send_data(socket_t * sock, const void * data, size_t size) { - return sock->fn_send(sock, data, size); -} - -static bool recv_data(socket_t * sock, void * data, size_t size) { - return sock->fn_recv(sock, data, size); -} - -static bool send_msg(socket_t * sock, const void * msg, size_t msg_size) { - if (!send_data(sock, &msg_size, sizeof(msg_size))) { +static bool send_msg(socket_ptr sock, const void * msg, size_t msg_size) { + if (!sock->send_data(&msg_size, sizeof(msg_size))) { return false; } - return send_data(sock, msg, msg_size); + return sock->send_data(msg, msg_size); } -static bool recv_msg(socket_t * sock, void * msg, size_t msg_size) { +static bool recv_msg(socket_ptr sock, void * msg, size_t msg_size) { uint64_t size; - if (!recv_data(sock, &size, sizeof(size))) { + if (!sock->recv_data(&size, sizeof(size))) { return false; } if (size != msg_size) { return false; } - return recv_data(sock, msg, msg_size); + return sock->recv_data(msg, msg_size); } -static bool recv_msg(socket_t * sock, std::vector & input) { +static bool recv_msg(socket_ptr sock, std::vector & input) { uint64_t size; - if (!recv_data(sock, &size, sizeof(size))) { + if (!sock->recv_data(&size, sizeof(size))) { return false; } try { @@ -945,7 +287,7 @@ static bool recv_msg(socket_t * sock, std::vector & input) { GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size); return false; } - return recv_data(sock, input.data(), size); + return sock->recv_data(input.data(), size); } static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { @@ -964,15 +306,15 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // No response -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size) { +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size) { uint8_t cmd_byte = cmd; - if (!send_data(sock.get(), &cmd_byte, sizeof(cmd_byte))) { + if (!sock->send_data(&cmd_byte, sizeof(cmd_byte))) { return false; } - if (!send_data(sock.get(), &input_size, sizeof(input_size))) { + if (!sock->send_data(&input_size, sizeof(input_size))) { return false; } - if (!send_data(sock.get(), input, input_size)) { + if (!sock->send_data(input, input_size)) { return false; } return true; @@ -980,18 +322,18 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // RPC response: | response_size (8 bytes) | response_data (response_size bytes) | -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { if (!send_rpc_cmd(sock, cmd, input, input_size)) { return false; } uint64_t out_size; - if (!recv_data(sock.get(), &out_size, sizeof(out_size))) { + if (!sock->recv_data(&out_size, sizeof(out_size))) { return false; } if (out_size != output_size) { return false; } - if (!recv_data(sock.get(), output, output_size)) { + if (!sock->recv_data(output, output_size)) { return false; } return true; @@ -1025,7 +367,6 @@ static std::shared_ptr get_socket(const std::string & endpoint) { static std::mutex mutex; std::lock_guard lock(mutex); static std::unordered_map> sockets; - static bool initialized = false; auto it = sockets.find(endpoint); if (it != sockets.end()) { @@ -1040,19 +381,10 @@ static std::shared_ptr get_socket(const std::string & endpoint) { return nullptr; } -#ifdef _WIN32 - if (!initialized) { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - return nullptr; - } - initialized = true; + if (!rpc_transport_init()) { + return nullptr; } -#else - GGML_UNUSED(initialized); -#endif - auto sock = socket_connect(host.c_str(), port); + auto sock = socket_t::connect(host.c_str(), port); if (sock == nullptr) { return nullptr; } @@ -2110,10 +1442,10 @@ rpc_server::~rpc_server() { } static void rpc_serve_client(const std::vector & backends, const char * cache_dir, - socket_t * sockfd) { + socket_ptr sock) { rpc_server server(backends, cache_dir); uint8_t cmd; - if (!recv_data(sockfd, &cmd, 1)) { + if (!sock->recv_data(&cmd, 1)) { return; } if (cmd != RPC_CMD_HELLO) { @@ -2123,7 +1455,7 @@ static void rpc_serve_client(const std::vector & backends, const // Read input_size and validate protocol version uint64_t hello_input_size; - if (!recv_data(sockfd, &hello_input_size, sizeof(hello_input_size))) { + if (!sock->recv_data(&hello_input_size, sizeof(hello_input_size))) { return; } @@ -2134,24 +1466,22 @@ static void rpc_serve_client(const std::vector & backends, const } rpc_msg_hello_req req = {}; - if (!recv_data(sockfd, &req, sizeof(req))) { + if (!sock->recv_data(&req, sizeof(req))) { return; } rpc_msg_hello_rsp rsp = {}; server.hello(rsp); - // Advertise server transport capabilities based on client's caps - sockfd->get_caps(rsp.conn_caps); - - if (!send_msg(sockfd, &rsp, sizeof(rsp))) { + sock->get_caps(rsp.conn_caps); + if (!send_msg(sock, &rsp, sizeof(rsp))) { return; } // Activate transport upgrade using client's caps - sockfd->update_caps(req.conn_caps); + sock->update_caps(req.conn_caps); while (true) { - if (!recv_data(sockfd, &cmd, 1)) { + if (!sock->recv_data(&cmd, 1)) { break; } if (cmd >= RPC_CMD_COUNT) { @@ -2165,115 +1495,115 @@ static void rpc_serve_client(const std::vector & backends, const return; } case RPC_CMD_DEVICE_COUNT: { - if (!recv_msg(sockfd, nullptr, 0)) { + if (!recv_msg(sock, nullptr, 0)) { return; } rpc_msg_device_count_rsp response; response.device_count = backends.size(); - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_ALLOC_BUFFER: { rpc_msg_alloc_buffer_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_alloc_buffer_rsp response; if (!server.alloc_buffer(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_ALLOC_SIZE: { rpc_msg_get_alloc_size_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_alloc_size_rsp response; if (!server.get_alloc_size(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_ALIGNMENT: { rpc_msg_get_alignment_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_alignment_rsp response; if (!server.get_alignment(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_MAX_SIZE: { rpc_msg_get_max_size_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_max_size_rsp response; if (!server.get_max_size(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_BUFFER_GET_BASE: { rpc_msg_buffer_get_base_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_buffer_get_base_rsp response; if (!server.buffer_get_base(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_FREE_BUFFER: { rpc_msg_free_buffer_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.free_buffer(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_BUFFER_CLEAR: { rpc_msg_buffer_clear_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.buffer_clear(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_SET_TENSOR: { std::vector input; - if (!recv_msg(sockfd, input)) { + if (!recv_msg(sock, input)) { return; } if (!server.set_tensor(input)) { @@ -2283,62 +1613,62 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_SET_TENSOR_HASH: { rpc_msg_set_tensor_hash_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_set_tensor_hash_rsp response; if (!server.set_tensor_hash(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_INIT_TENSOR: { rpc_msg_init_tensor_req request; - if (!recv_msg(sockfd, &request,sizeof(request))) { + if (!recv_msg(sock, &request,sizeof(request))) { return; } if (!server.init_tensor(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_GET_TENSOR: { rpc_msg_get_tensor_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } std::vector response; if (!server.get_tensor(request, response)) { return; } - if (!send_msg(sockfd, response.data(), response.size())) { + if (!send_msg(sock, response.data(), response.size())) { return; } break; } case RPC_CMD_COPY_TENSOR: { rpc_msg_copy_tensor_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_copy_tensor_rsp response; if (!server.copy_tensor(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GRAPH_COMPUTE: { std::vector input; - if (!recv_msg(sockfd, input)) { + if (!recv_msg(sock, input)) { return; } if (!server.graph_compute(input)) { @@ -2348,7 +1678,7 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_GRAPH_RECOMPUTE: { rpc_msg_graph_recompute_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.graph_recompute(request)) { @@ -2358,14 +1688,14 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_GET_DEVICE_MEMORY: { rpc_msg_get_device_memory_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_device_memory_rsp response; if (!server.get_device_memory(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; @@ -2424,36 +1754,28 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir #else printf(" transport : TCP\n"); #endif // GGML_RPC_RDMA -#ifdef _WIN32 - { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - fprintf(stderr, "WSAStartup failed: %d\n", res); - return; - } + if (!rpc_transport_init()) { + fprintf(stderr, "Failed to initialize RPC transport\n"); + return; } -#endif - auto server_socket = create_server_socket(host.c_str(), port); + auto server_socket = socket_t::create_server(host.c_str(), port); if (server_socket == nullptr) { fprintf(stderr, "Failed to create server socket\n"); return; } while (true) { - auto client_socket = socket_accept(server_socket->fd); + auto client_socket = server_socket->accept(); if (client_socket == nullptr) { fprintf(stderr, "Failed to accept client connection\n"); return; } printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backends, cache_dir, client_socket.get()); + rpc_serve_client(backends, cache_dir, client_socket); printf("Client connection closed\n"); fflush(stdout); } -#ifdef _WIN32 - WSACleanup(); -#endif + rpc_transport_shutdown(); for (auto backend : backends) { ggml_backend_free(backend); } diff --git a/ggml/src/ggml-rpc/transport.cpp b/ggml/src/ggml-rpc/transport.cpp new file mode 100644 index 00000000000..a728152421f --- /dev/null +++ b/ggml/src/ggml-rpc/transport.cpp @@ -0,0 +1,683 @@ +#include "transport.h" +#include "ggml-impl.h" + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +# include +# include +# include +# include +# include +#endif +#include +#include +#include + +#ifdef GGML_RPC_RDMA +# include +# include +# ifndef _WIN32 +# include +# endif +#endif // GGML_RPC_RDMA + +#ifdef _WIN32 +typedef SOCKET sockfd_t; +using ssize_t = __int64; +#else +typedef int sockfd_t; +#endif + +static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); + +#define LOG_DBG(...) \ + do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0) + +#ifdef GGML_RPC_RDMA +static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) +static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB +static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes +using rdma_gid_t = std::array; + +struct rdma_conn { + struct ibv_context * ctx = nullptr; + struct ibv_pd * pd = nullptr; + struct ibv_cq * scq = nullptr; // send completions + struct ibv_cq * rcq = nullptr; // recv completions + struct ibv_qp * qp = nullptr; + + void * tx_buf = nullptr; + struct ibv_mr * tx_mr = nullptr; + + void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous + struct ibv_mr * rx_mr = nullptr; + int rx_head = 0; + + uint32_t max_inline = 0; + + uint8_t * rx_slot(int i) const { + return static_cast(rx_buf) + static_cast(i) * RDMA_CHUNK; + } + + bool post_rx(int i) { + struct ibv_sge sge = {}; + sge.addr = (uintptr_t)rx_slot(i); + sge.length = RDMA_CHUNK; + sge.lkey = rx_mr->lkey; + struct ibv_recv_wr wr = {}, * bad = nullptr; + wr.wr_id = (uint64_t)i; + wr.sg_list = &sge; + wr.num_sge = 1; + return ibv_post_recv(qp, &wr, &bad) == 0; + } + + ~rdma_conn() { + if (tx_mr) ibv_dereg_mr(tx_mr); + if (rx_mr) ibv_dereg_mr(rx_mr); + free(tx_buf); + free(rx_buf); + if (qp) ibv_destroy_qp(qp); + if (scq) ibv_destroy_cq(scq); + if (rcq) ibv_destroy_cq(rcq); + if (pd) ibv_dealloc_pd(pd); + if (ctx) ibv_close_device(ctx); + } +}; + +// Local RDMA parameters captured during the probe phase and later consumed +// by rdma_activate() after the remote side's caps arrive via HELLO. +struct rdma_local_info { + uint32_t qpn = 0; + uint32_t psn = 0; + uint8_t gid[RDMA_GID_SIZE] = {}; + uint8_t ib_port = 0; + int gid_idx = 0; + enum ibv_mtu path_mtu = IBV_MTU_1024; +}; + +struct rdma_caps { + uint32_t qpn; + uint32_t psn; + uint8_t gid[RDMA_GID_SIZE]; +}; + +static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); + +#endif // GGML_RPC_RDMA + +struct socket_t::impl { + impl(sockfd_t fd) : use_rdma(false), fd(fd) {} + ~impl(); + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + +#ifdef GGML_RPC_RDMA + bool tcp_peer_closed(); + std::optional rdma_build_target_gid(); + bool rdma_probe(); + bool rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid); + bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc); + bool rdma_send(const void * data, size_t size); + bool rdma_recv(void * data, size_t size); + + std::unique_ptr rdma; + rdma_local_info rdma_local = {}; +#endif // GGML_RPC_RDMA + bool use_rdma; + sockfd_t fd; +}; + +socket_t::impl::~impl() { +#ifdef GGML_RPC_RDMA + rdma.reset(); +#endif // GGML_RPC_RDMA + LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); +#ifdef _WIN32 + if (fd != INVALID_SOCKET) closesocket(this->fd); +#else + if (fd >= 0) close(this->fd); +#endif +} + +#ifdef GGML_RPC_RDMA + +bool socket_t::impl::tcp_peer_closed() { + if (fd < 0) return false; +#ifndef _WIN32 + struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; + int r = poll(&pfd, 1, 0); + return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP)); +#else + return false; +#endif +} + +// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. +// Used to match the socket's local IP against the kernel's GID table so that +// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: +// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) +// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) +// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is +// Returns std::nullopt on unsupported family or getsockname failure. +std::optional socket_t::impl::rdma_build_target_gid() { + sockaddr_storage addr = {}; + socklen_t addr_len = sizeof(addr); + if (getsockname(fd, reinterpret_cast(&addr), &addr_len) != 0) { + return std::nullopt; + } + rdma_gid_t target = {}; + if (addr.ss_family == AF_INET) { + const auto * a = reinterpret_cast(&addr); + target[10] = 0xff; + target[11] = 0xff; + memcpy(&target[12], &a->sin_addr, 4); + return target; + } + if (addr.ss_family == AF_INET6) { + const auto * a = reinterpret_cast(&addr); + memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); + return target; + } + return std::nullopt; +} + +bool socket_t::impl::rdma_probe() { + const char * dev_env = std::getenv("GGML_RDMA_DEV"); + const char * gid_env = std::getenv("GGML_RDMA_GID"); + + auto target_gid = rdma_build_target_gid(); + if (!target_gid) { + return false; + } + + const uint8_t ib_port = 1; + int num_devs = 0; + ibv_device ** devs = ibv_get_device_list(&num_devs); + if (!devs || num_devs == 0) return false; + + ibv_context * ibctx = nullptr; + const char * matched_dev = nullptr; + int gid_idx = gid_env ? atoi(gid_env) : -1; + int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB + + for (int d = 0; d < num_devs; d++) { + const char * dn = ibv_get_device_name(devs[d]); + if (dev_env && strcmp(dev_env, dn) != 0) continue; + + ibv_context * ctx = ibv_open_device(devs[d]); + if (!ctx) continue; + + ibv_port_attr pa; + if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } + + int found_gid = gid_idx; + int found_version = IBV_GID_TYPE_IB; + if (found_gid < 0) { + // Find a GID on this port whose bytes equal the local TCP address + // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 + // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths + // are avoided. ibv_query_gid_ex returns gid+type in one call. + int v2_idx = -1; + int v1_idx = -1; + for (int i = 0; i < pa.gid_tbl_len; i++) { + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; + if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; + if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { + v2_idx = i; + } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { + v1_idx = i; + } + } + if (v2_idx >= 0) { + found_gid = v2_idx; + found_version = IBV_GID_TYPE_ROCE_V2; + } else if (v1_idx >= 0) { + found_gid = v1_idx; + found_version = IBV_GID_TYPE_ROCE_V1; + } + } else { + // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { + found_version = entry.gid_type; + } + } + if (found_gid >= 0) { + ibctx = ctx; + gid_idx = found_gid; + gid_version = found_version; + matched_dev = dn; + rdma_local.path_mtu = pa.active_mtu; + break; + } + ibv_close_device(ctx); + } + ibv_free_device_list(devs); + if (!ibctx) return false; + + rdma_local.ib_port = ib_port; + rdma_local.gid_idx = gid_idx; + + rdma = std::make_unique(); + rdma->ctx = ibctx; + + rdma->pd = ibv_alloc_pd(ibctx); + if (!rdma->pd) return false; + + rdma->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); + rdma->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); + if (!rdma->scq || !rdma->rcq) return false; + + ibv_qp_init_attr qia = {}; + qia.send_cq = rdma->scq; + qia.recv_cq = rdma->rcq; + qia.qp_type = IBV_QPT_RC; + qia.cap.max_send_wr = 4; + qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; + qia.cap.max_send_sge = 1; + qia.cap.max_recv_sge = 1; + qia.cap.max_inline_data = 256; + + rdma->qp = ibv_create_qp(rdma->pd, &qia); + if (!rdma->qp) return false; + rdma->max_inline = qia.cap.max_inline_data; + + rdma->tx_buf = aligned_alloc(4096, RDMA_CHUNK); + rdma->rx_buf = aligned_alloc(4096, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK); + if (!rdma->tx_buf || !rdma->rx_buf) return false; + + rdma->tx_mr = ibv_reg_mr(rdma->pd, rdma->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); + rdma->rx_mr = ibv_reg_mr(rdma->pd, rdma->rx_buf, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!rdma->tx_mr || !rdma->rx_mr) return false; + + ibv_gid local_gid; + if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return false; + + rdma_local.qpn = rdma->qp->qp_num; + rdma_local.psn = rdma->qp->qp_num & 0xffffff; + memcpy(&rdma_local.gid, &local_gid, RDMA_GID_SIZE); + + const char * ver_str = ""; + if (gid_version == IBV_GID_TYPE_ROCE_V2) { + ver_str = " RoCEv2"; + } else if (gid_version == IBV_GID_TYPE_ROCE_V1) { + ver_str = " RoCEv1"; + } + GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", + matched_dev, gid_idx, ver_str, rdma_local.qpn, rdma->max_inline); + return true; +} + +// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. +// On success, the connection is live and ready for rdma_send/rdma_recv. +bool socket_t::impl::rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { + // RESET -> INIT + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_INIT; + a.port_num = rdma_local.ib_port; + a.pkey_index = 0; + a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + return false; + } + } + + for (int i = 0; i < RDMA_RX_DEPTH; i++) { + if (!rdma->post_rx(i)) return false; + } + + // INIT -> RTR + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTR; + a.path_mtu = rdma_local.path_mtu; + a.dest_qp_num = remote_qpn; + a.rq_psn = remote_psn; + a.max_dest_rd_atomic = 1; + a.min_rnr_timer = 1; + a.ah_attr.is_global = 1; + memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); + a.ah_attr.grh.hop_limit = 1; + a.ah_attr.grh.sgid_index = rdma_local.gid_idx; + a.ah_attr.dlid = 0; + a.ah_attr.port_num = rdma_local.ib_port; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { + return false; + } + } + + // RTR -> RTS + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTS; + a.timeout = 14; + a.retry_cnt = 7; + a.rnr_retry = 7; + a.sq_psn = rdma_local.psn; + a.max_rd_atomic = 1; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { + return false; + } + } + + GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n", + rdma_local.qpn, remote_qpn, 128 << rdma_local.path_mtu, RDMA_RX_DEPTH); + return true; +} + +bool socket_t::impl::rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc) { + for (uint64_t s = 0; ; s++) { + int n = ibv_poll_cq(cq, 1, wc); + if (n > 0) { + if (wc->status != IBV_WC_SUCCESS) { + GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", + wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); + } + return wc->status == IBV_WC_SUCCESS; + } + if (n < 0) return false; + if ((s & 0xFFFFF) == 0 && s > 0) { + if (tcp_peer_closed()) { + return false; + } + } + } +} + +bool socket_t::impl::rdma_send(const void * data, size_t size) { + rdma_conn * c = rdma.get(); + const uint8_t * src = (const uint8_t *)data; + size_t rem = size; + while (rem > 0) { + size_t chunk = std::min(rem, RDMA_CHUNK); + + struct ibv_sge sge = {}; + struct ibv_send_wr wr = {}, * bad = nullptr; + wr.opcode = IBV_WR_SEND; + wr.sg_list = &sge; + wr.num_sge = 1; + + if (chunk <= c->max_inline) { + sge.addr = (uintptr_t)src; + sge.length = chunk; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; + } else { + memcpy(c->tx_buf, src, chunk); + sge.addr = (uintptr_t)c->tx_buf; + sge.length = chunk; + sge.lkey = c->tx_mr->lkey; + wr.send_flags = IBV_SEND_SIGNALED; + } + + if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; + struct ibv_wc wc; + if (!rdma_poll(c->scq, &wc)) return false; + + src += chunk; + rem -= chunk; + } + return true; +} + +bool socket_t::impl::rdma_recv(void * data, size_t size) { + rdma_conn * c = rdma.get(); + uint8_t * dst = (uint8_t *)data; + size_t rem = size; + while (rem > 0) { + struct ibv_wc wc; + if (!rdma_poll(c->rcq, &wc)) return false; + + int slot = (int)wc.wr_id; + size_t got = wc.byte_len; + memcpy(dst, c->rx_slot(slot), got); + + if (!c->post_rx(slot)) return false; + + dst += got; + rem -= got; + } + return true; +} + +#endif // GGML_RPC_RDMA + +bool socket_t::impl::send_data(const void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_send(data, size); + } +#endif + size_t bytes_sent = 0; + while (bytes_sent < size) { + size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); + ssize_t n = send(fd, (const char *)data + bytes_sent, size_to_send, 0); + if (n < 0) { + GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", + bytes_sent, size_to_send); + return false; + } + bytes_sent += (size_t)n; + } + return true; +} + +bool socket_t::impl::recv_data(void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_recv(data, size); + } +#endif + size_t bytes_recv = 0; + while (bytes_recv < size) { + size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE); + ssize_t n = recv(fd, (char *)data + bytes_recv, size_to_recv, 0); + if (n < 0) { + GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", + bytes_recv, size_to_recv); + return false; + } + if (n == 0) { + LOG_DBG("recv returned 0 (peer closed?)\n"); + return false; + } + bytes_recv += (size_t)n; + } + return true; +} + +void socket_t::impl::get_caps(uint8_t * local_caps) { + memset(local_caps, 0, RPC_CONN_CAPS_SIZE); +#ifdef GGML_RPC_RDMA + rdma_local = {}; + if (rdma_probe()) { + rdma_caps rc = {}; + rc.qpn = rdma_local.qpn; + rc.psn = rdma_local.psn; + memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); + memcpy(local_caps, &rc, sizeof(rc)); + } else { + rdma.reset(); + } +#endif // GGML_RPC_RDMA +} + +void socket_t::impl::update_caps(const uint8_t * remote_caps) { +#ifdef GGML_RPC_RDMA + if (!rdma) { + return; + } + rdma_caps rc = {}; + memcpy(&rc, remote_caps, sizeof(rc)); + if (rc.qpn == 0) { + rdma.reset(); + return; + } + if (rdma_activate(rc.qpn, rc.psn, rc.gid)) { + use_rdma = true; + } else { + GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); + rdma.reset(); + } +#else + (void)remote_caps; +#endif // GGML_RPC_RDMA +} + + +///////////////////////////////////////////////////////////////////////////// + +socket_t::socket_t(std::unique_ptr p) : pimpl(std::move(p)) {} + +socket_t::~socket_t() = default; + +bool socket_t::send_data(const void * data, size_t size) { + return pimpl->send_data(data, size); +} + +bool socket_t::recv_data(void * data, size_t size) { + return pimpl->recv_data(data, size); +} + +void socket_t::get_caps(uint8_t * local_caps) { + return pimpl->get_caps(local_caps); +} + +void socket_t::update_caps(const uint8_t * remote_caps) { + return pimpl->update_caps(remote_caps); +} + +static bool is_valid_fd(sockfd_t sockfd) { +#ifdef _WIN32 + return sockfd != INVALID_SOCKET; +#else + return sockfd >= 0; +#endif +} + +static bool set_no_delay(sockfd_t sockfd) { + int flag = 1; + // set TCP_NODELAY to disable Nagle's algorithm + int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); + return ret == 0; +} + +static bool set_reuse_addr(sockfd_t sockfd) { + int flag = 1; + int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); + return ret == 0; +} + +socket_ptr socket_t::accept() { + auto client_socket_fd = ::accept(pimpl->fd, NULL, NULL); + if (!is_valid_fd(client_socket_fd)) { + return nullptr; + } + if (!set_no_delay(client_socket_fd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(client_socket_fd))); +} + +socket_ptr socket_t::create_server(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_reuse_addr(sockfd)) { + GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); + return nullptr; + } + if (inet_addr(host) == INADDR_NONE) { + GGML_LOG_ERROR("Invalid host address: %s\n", host); + return nullptr; + } + struct sockaddr_in serv_addr; + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = inet_addr(host); + serv_addr.sin_port = htons(port); + + if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { + return nullptr; + } + if (listen(sockfd, 1) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +socket_ptr socket_t::connect(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_no_delay(sockfd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + struct hostent * server = gethostbyname(host); + if (server == NULL) { + GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); + return nullptr; + } + memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); + if (::connect(sockfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +#ifdef _WIN32 +static std::mutex g_rpc_transport_mu; +static bool g_rpc_transport_wsa_started = false; +#endif + +bool rpc_transport_init() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (g_rpc_transport_wsa_started) { + return true; + } + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return false; + } + g_rpc_transport_wsa_started = true; + return true; +#else + return true; +#endif +} + +void rpc_transport_shutdown() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (!g_rpc_transport_wsa_started) { + return; + } + WSACleanup(); + g_rpc_transport_wsa_started = false; +#endif +} diff --git a/ggml/src/ggml-rpc/transport.h b/ggml/src/ggml-rpc/transport.h new file mode 100644 index 00000000000..73b85cc530a --- /dev/null +++ b/ggml/src/ggml-rpc/transport.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +struct socket_t; +typedef std::shared_ptr socket_ptr; + +static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB +static constexpr size_t RPC_CONN_CAPS_SIZE = 24; + +struct socket_t { + ~socket_t(); + + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + + socket_ptr accept(); + + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + + static socket_ptr create_server(const char * host, int port); + static socket_ptr connect(const char * host, int port); + +private: + struct impl; + explicit socket_t(std::unique_ptr p); + std::unique_ptr pimpl; +}; + +bool rpc_transport_init(); +void rpc_transport_shutdown(); From 171f037fbaef10c7901018a2be91e85764d581c2 Mon Sep 17 00:00:00 2001 From: texasich <101962694+texasich@users.noreply.github.com> Date: Sun, 19 Apr 2026 02:25:05 -0500 Subject: [PATCH 156/249] cmake: remove CMP0194 policy to restore MSVC builds (llama/21934) #21630 added the CMP0194 NEW policy to silence a CMake warning, but on Windows runners it caused CMake to prefer the MinGW toolchain for ASM and broke MSVC builds. Reverting only that policy block restores the previous working behavior. The CMake 4.1+ warning comes back, but that is cosmetic and does not break any platform. Reported-by: oobabooga Refs: #21630 Co-authored-by: texasich --- ggml/CMakeLists.txt | 6 ------ 1 file changed, 6 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 6b65ecd6e5c..a0eb9204eab 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -1,11 +1,5 @@ cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories. -# ref: https://cmake.org/cmake/help/latest/policy/CMP0194.html -# MSVC is not a valid assembler for the ASM language. -# Set to NEW to avoid a warning on CMake 4.1+ with MSVC. -if (POLICY CMP0194) - cmake_policy(SET CMP0194 NEW) -endif() project("ggml" C CXX ASM) ### GGML Version From 671fd1527a4aeb1b186d54302d42a8f5451feb82 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sun, 19 Apr 2026 15:18:35 +0530 Subject: [PATCH 157/249] ggml : reduce CPU overhead in meta backend (llama/22041) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cache subgraph splits when cgraph is unchanged Skip per-call subgraph construction in ggml_backend_meta_graph_compute when the same ggml_cgraph is used consecutively. Assign uid to every sub-graph so that CUDA's fast uid check path hits too. * Address review comments * Keep the scope as is * Rename last_uid and last_n_subgraphs field. Remove last_max_tmp_size field. Refactor code. * Address review comments * Update ggml/src/ggml-backend-meta.cpp Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-backend-meta.cpp Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-backend-meta.cpp | 307 +++++++++++++++++---------------- 1 file changed, 160 insertions(+), 147 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 24f6bc0639d..39651adc1c1 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1456,6 +1456,8 @@ struct ggml_backend_meta_context { int max_nnodes = 0; size_t max_tmp_size = 0; size_t max_subgraphs = 0; + size_t n_subgraphs = 0; + uint64_t uid = 0; void * comm_ctx = nullptr; ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr; @@ -1616,6 +1618,9 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, const size_t n_backends = ggml_backend_meta_n_backends(backend); ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + // If the previous cgraph had a defined UID it can be used to skip rebuilding the subgraphs per simple backend. + const bool needs_rebuild = (cgraph->uid == 0) || (cgraph->uid != backend_ctx->uid); + bool max_nnodes_raised = false; if (cgraph->n_nodes > backend_ctx->max_nnodes) { for (size_t j = 0; j < n_backends; j++) { @@ -1625,173 +1630,181 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } backend_ctx->max_nnodes = cgraph->n_nodes; max_nnodes_raised = true; + assert(needs_rebuild); } - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { - // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. - // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. - bcj.nodes[i] = node; - continue; + + if (needs_rebuild) { + size_t n_subgraphs = 0; + size_t max_tmp_size = 0; + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. + // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. + bcj.nodes[i] = node; + continue; + } + bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); + GGML_ASSERT(bcj.nodes[i]); } - bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); - GGML_ASSERT(bcj.nodes[i]); } - } - size_t n_subgraphs = 0; - size_t max_tmp_size = 0; - { - // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: - auto get_i_delayed = [&](const int i) -> int { - int id = i; // i_delayed - int idr = i; // i_delayed return, last safe return value - - ggml_tensor * node = cgraph->nodes[id]; - int32_t n_used = ggml_node_get_use_count(cgraph, id); - if (id + 1 >= cgraph->n_nodes) { - return idr; - } - { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op == GGML_OP_ADD_ID && next->src[0] == node && - ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && - ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - node = next; + { + // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: + auto get_i_delayed = [&](const int i) -> int { + int id = i; // i_delayed + int idr = i; // i_delayed return, last safe return value + + ggml_tensor * node = cgraph->nodes[id]; + int32_t n_used = ggml_node_get_use_count(cgraph, id); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_ADD_ID && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && + ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_MUL && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + + if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { + return idr; + } + for (int32_t k = 0; k < n_used; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || + next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || + ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } id++; - idr = id; - n_used = ggml_node_get_use_count(cgraph, id); } - } - if (id + 1 >= cgraph->n_nodes) { - return idr; - } - { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op == GGML_OP_MUL && next->src[0] == node && - ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - node = next; + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } id++; - idr = id; - n_used = ggml_node_get_use_count(cgraph, id); } - } - - if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { + for (int32_t k = 0; k < n_used - 2; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + idr = id; return idr; - } - for (int32_t k = 0; k < n_used; k++) { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || - next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || - ggml_node_get_use_count(cgraph, id+1) != 1) { - return idr; + }; + + int i_start = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + continue; } - id++; - } - { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || - next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { - return idr; + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); } - id++; - } - for (int32_t k = 0; k < n_used - 2; k++) { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || - next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { - return idr; + const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; + if (!new_subgraph) { + continue; } - id++; - } - idr = id; - return idr; - }; - - int i_start = 0; - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { - continue; - } - const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); - if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { - max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); - } - const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; - if (!new_subgraph) { - continue; + + i = get_i_delayed(i); + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.cgraphs[n_subgraphs].offset = i_start; + } + n_subgraphs++; + i_start = i + 1; } + GGML_ASSERT(i_start == cgraph->n_nodes); + } - i = get_i_delayed(i); + backend_ctx->uid = cgraph->uid; + backend_ctx->n_subgraphs = n_subgraphs; + if (max_tmp_size > backend_ctx->max_tmp_size) { for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; - bcj.cgraphs[n_subgraphs].offset = i_start; + bcj.buf.reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); + } + backend_ctx->max_tmp_size = max_tmp_size; + } + + if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { + backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); + const size_t n_reduce_steps = backend_ctx->n_reduce_steps(); + const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step + const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step + const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); + const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); + const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); + ggml_init_params params = { + /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + backend_ctx->ctx.reset(ggml_init(params)); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i = 0; i < n_subgraphs; i++) { + bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); + } + } + backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { + backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); + } + backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { + backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1); } - n_subgraphs++; - i_start = i + 1; } - GGML_ASSERT(i_start == cgraph->n_nodes); - } - if (max_tmp_size > backend_ctx->max_tmp_size) { for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; - bcj.buf.reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); - } - backend_ctx->max_tmp_size = max_tmp_size; - } - - - if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { - backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); - const size_t n_reduce_steps = backend_ctx->n_reduce_steps(); - const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step - const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step - const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); - const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); - const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); - ggml_init_params params = { - /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - backend_ctx->ctx.reset(ggml_init(params)); - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - for (size_t i = 0; i < n_subgraphs; i++) { - bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); - } - } - backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); - for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { - backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); - } - backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); - for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { - backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1); - } - } - - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { - ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; - const size_t i_node_start = bcj.cgraphs[i_graph].offset; - const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; - cgraph_ij->n_nodes = i_node_stop - i_node_start; - ggml_hash_set_reset(&cgraph_ij->visited_hash_set); - for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { - ggml_tensor * node_ij = bcj.nodes[i_node]; - cgraph_ij->nodes[i_node - i_node_start] = node_ij; - const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); - const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); - cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; + for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { + ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; + const size_t i_node_start = bcj.cgraphs[i_graph].offset; + const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; + cgraph_ij->n_nodes = i_node_stop - i_node_start; + ggml_hash_set_reset(&cgraph_ij->visited_hash_set); + for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { + ggml_tensor * node_ij = bcj.nodes[i_node]; + cgraph_ij->nodes[i_node - i_node_start] = node_ij; + const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); + const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); + cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; + } + cgraph_ij->uid = ggml_graph_next_uid(); } } } @@ -1898,7 +1911,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, }; - for (size_t i = 0; i < n_subgraphs; i++) { + for (size_t i = 0; i < backend_ctx->n_subgraphs; i++) { for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main); @@ -1907,7 +1920,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } } - if (n_backends > 1 && i < n_subgraphs - 1) { + if (n_backends > 1 && i < backend_ctx->n_subgraphs - 1) { bool backend_allreduce_success = false; if (backend_ctx->comm_ctx) { std::vector nodes; From 945746b40c2fe7983b02f00fb0c653476334a7bc Mon Sep 17 00:00:00 2001 From: uvos Date: Sun, 19 Apr 2026 12:59:44 +0200 Subject: [PATCH 158/249] HIP: Remove unesscary NCCL_CHECK (llama/21914) --- ggml/src/ggml-cuda/vendors/hip.h | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 898fec31e36..52c38908e06 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -33,7 +33,6 @@ #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} -#define NCCL_CHECK(fn) {ncclResult_t err = fn; if(err != ncclSuccess) { GGML_ABORT("RCCL Failure RCCL returned: %i\n", err); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) From b8f57c9c50e389bd4e21e3ebc0c9db4506bf2e2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 19 Apr 2026 18:26:59 +0200 Subject: [PATCH 159/249] CUDA: refactor mma data loading for AMD (llama/22051) * CUDA: refactor mma data loading for AMD * fix CDNA MMQ occupancy * fix CDNA3 mma * fix RDNA3 compile --- ggml/src/ggml-cuda/common.cuh | 4 - ggml/src/ggml-cuda/fattn-mma-f16.cuh | 57 ++----- ggml/src/ggml-cuda/mma.cuh | 245 +++++++++------------------ ggml/src/ggml-cuda/mmq.cuh | 201 ++-------------------- 4 files changed, 112 insertions(+), 395 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ddf50baf495..3aec1742ee1 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -269,10 +269,6 @@ static const char * cu_get_error_str(CUresult err) { #define FLASH_ATTN_AVAILABLE #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) -#if defined(TURING_MMA_AVAILABLE) -#define LDMATRIX_TRANS_AVAILABLE -#endif // defined(TURING_MMA_AVAILABLE) - static bool fp16_available(const int cc) { return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index b613ae61fb8..e185449d491 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -305,12 +305,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { constexpr int warp_size = ggml_cuda_get_physical_warp_size(); // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + // The minimum granularity is 16 bytes. + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; if constexpr (use_cp_async) { + static_assert(warp_size == 32, "bad warp_size"); static_assert(!oob_check, "OOB check not compatible with cp_async"); constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); - const int chunks_per_row = D2 / h2_per_chunk; const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); @@ -348,11 +349,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( // 6: max 1*16= 16 bytes, 8 half ggml_cuda_unroll<6>{}(load); } else { - // TODO use ggml_cuda_memcpy_1 + const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}}; auto load = [&] __device__ (const int n) { - const int stride_k = warp_size >> n; - const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k); - const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_k = 32 >> n; + const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { @@ -371,15 +372,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); + ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4, + !oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero); } } }; - // 1: max 32* 4=128 bytes, 64 half - // 2: max 16* 4= 64 bytes, 32 half - // 3: max 8* 4= 32 bytes, 16 half - // 4: max 4* 4= 16 bytes, 8 half - ggml_cuda_unroll<4>{}(load); + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); } } @@ -862,11 +866,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } -#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) - T_A_VKQ A_identity; - make_identity_mat(A_identity); -#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) - // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { @@ -897,29 +896,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. -#if defined(LDMATRIX_TRANS_AVAILABLE) load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); -#elif defined(AMD_MFMA_AVAILABLE) - // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg]. - // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T. - // Load with transposed addressing: 4 strided half loads. - { - const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2; - const half * xs0_h = (const half *) xs0; - const int stride_h = stride_tile_V * 2; // stride in half units - half * A_h = (half *) A.x; -#pragma unroll - for (int l = 0; l < 4; ++l) { - A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16]; - } - } -#else - // TODO: Try to transpose tile_V when loading gmem to smem. - // Use mma to transpose T_A_VKQ for RDNA. - T_A_VKQ A_trans; - load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - mma(A, A_trans, A_identity); -#endif // defined(LDMATRIX_TRANS_AVAILABLE) if constexpr (T_B_KQ::I == 8) { mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index c91dd2d9ad6..b0f674635f1 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -86,17 +86,12 @@ namespace ggml_cuda_mma { // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR - static constexpr bool is_i_major(const data_layout dl) { - return dl == DATA_LAYOUT_I_MAJOR || - dl == DATA_LAYOUT_I_MAJOR_MIRRORED; - } - static constexpr __device__ data_layout get_input_data_layout() { -#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) return DATA_LAYOUT_I_MAJOR_MIRRORED; #else return DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) } template @@ -113,7 +108,6 @@ namespace ggml_cuda_mma { T x[ne] = {0}; static constexpr __device__ bool supported() { - if (I == 64 && J == 2) return true; if (I == 16 && J == 8) return true; if (I == 32 && J == 4) return true; if (I == 16 && J == 16) return true; @@ -122,7 +116,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + if constexpr (I == 16 && J == 4) { return threadIdx.x % 16; } else if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; @@ -139,8 +133,8 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> - return (2 * ((threadIdx.x / 16) % 2) + l); + if constexpr (I == 16 && J == 4) { + return threadIdx.x / 16; } else if constexpr (I == 16 && J == 8) { return 2 * (threadIdx.x / 16) + l; } else if constexpr (I == 32 && J == 4) { @@ -154,7 +148,7 @@ namespace ggml_cuda_mma { return -1; } } -#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#elif defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -283,7 +277,7 @@ namespace ggml_cuda_mma { static constexpr int J = J_; static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -407,7 +401,7 @@ namespace ggml_cuda_mma { return -1; } } -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) }; template @@ -701,57 +695,12 @@ namespace ggml_cuda_mma { } #endif // defined(TURING_MMA_AVAILABLE) - static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) { -#if defined(RDNA4) - const int row = t.get_i(0); - const int left_right = t.get_j(0) / 4; - const int up_down = row / 8; - const int idx = row % 8; - reinterpret_cast(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f; -#else - GGML_UNUSED_VARS(t); - NO_DEVICE_CODE; -#endif // defined(RDNA4) - } - template static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { -#if defined(AMD_MFMA_AVAILABLE) - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } else { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } -#elif defined(AMD_WMMA_AVAILABLE) - // All wmma layout has contiguous data when i-major. - if constexpr (is_i_major(dl)) { - // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes() - constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes(); - if constexpr (sizeof(t.x) > aligned_copy_bytes) { - static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size"); - constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes; -#pragma unroll - for (int i = 0; i < aligned_copy_count; ++i) { - ggml_cuda_memcpy_1(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i)); - } - } else { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } - } else { -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } -#else #pragma unroll for (int l = 0; l < t.ne; ++l) { t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } -#endif // defined(AMD_MFMA_AVAILABLE) } template @@ -764,26 +713,37 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else - load_generic(t, xs0, stride); + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } - template + template static __device__ __forceinline__ void load_ldmatrix( - tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { + tile<16, 4, T, dl> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<8>(t.x + 2, xs0 + t.get_i(0)*stride + 2); +#else + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 4, "bad ne"); + ggml_cuda_memcpy_1<4>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #endif // TURING_MMA_AVAILABLE } @@ -796,19 +756,26 @@ namespace ggml_cuda_mma { asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); -#else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA -#if 1 - // TODO: more generic handling - static_assert(sizeof(T) == 4, "bad type size"); +#elif defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0); ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 32, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<16>(t.x + 4, xs0 + t.get_i(0)*stride + 4); #else - load_generic(t, xs0, stride); -#endif // 1 + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -827,23 +794,30 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void load_ldmatrix( tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } template static __device__ __forceinline__ void load_ldmatrix_trans( tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE - int * xi = (int * ) t.x; + int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); +#elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + half * xh = (half *) t.x; +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; + xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + } #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; @@ -1218,73 +1192,27 @@ namespace ggml_cuda_mma { using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * acc = (int32x4_t *) D.x; #if defined(CDNA4) || defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); #elif defined(CDNA2) || defined(CDNA1) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], B.x[1], acc[0], 0, 0, 0); #endif // defined(CDNA4) || defined(CDNA3) - #elif defined(AMD_WMMA_AVAILABLE) - using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; - #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); - + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[1], true, b_vec[1], acc[0], true); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[1], true, b_vec[1], acc[0], true); #endif // RDNA4 - #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -1297,19 +1225,10 @@ namespace ggml_cuda_mma { using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; int32x16_t * acc = (int32x16_t *) D.x; #if defined(CDNA4) || defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); #elif defined(CDNA2) || defined(CDNA1) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], B.x[1], acc[0], 0, 0, 0); #endif // defined(CDNA4) || defined(CDNA3) #else @@ -1329,7 +1248,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void mma( tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1344,12 +1263,12 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } static __device__ __forceinline__ void mma( tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1364,41 +1283,35 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } template static __device__ __forceinline__ void mma( tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * acc = (int32x4_t *) D.x; +#if defined(CDNA4) || defined(CDNA3) + const int64_t xA = uint32_t(A.x[0]); + const int64_t xB = uint32_t(B.x[0]); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(xA, xB, acc[0], 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); +#endif // defined(CDNA4) || defined(CDNA3) +#elif defined(AMD_WMMA_AVAILABLE) using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], false); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], false); #endif // RDNA4 #else GGML_UNUSED(D); diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 28b662df925..b1a319de9be 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -104,7 +104,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : + return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -114,9 +114,9 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 128; -#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #if defined(GGML_USE_HIP) return 64; @@ -1054,13 +1054,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); float dB; const int j = j0 + tile_C::get_j(0); @@ -1295,13 +1295,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -1435,57 +1435,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1510,13 +1460,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; @@ -1742,74 +1692,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; - const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 - : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y - : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); - - tile_C Cm; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tile_A A1; - A1.x[0] = 0x01010101; - A1.x[1] = 0x01010101; - mma(Cm, A1, B[0]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C Cd; - mma(Cd, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); - float tmp = Cd.x[l]*dm.x; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tmp -= Cm.x[l]*dm.y; - } - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; - sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1834,13 +1717,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y; @@ -2573,59 +2456,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -2651,13 +2482,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; From 931cf2f3a81af3f347cc01bed686d374c851a2e4 Mon Sep 17 00:00:00 2001 From: Katostrofik Date: Mon, 20 Apr 2026 01:39:45 -0400 Subject: [PATCH 160/249] Fix reorder MMVQ assert on unaligned vocab sizes (llama/22035) * [SYCL] Fix reorder MMVQ assert on unaligned vocab sizes The reorder mul_mat_vec_q dispatchers for Q4_0, Q8_0, Q4_K, and Q6_K asserted that block_num_y was a multiple of 16 subgroups. Models with a vocab size not divisible by 16 (for example HY-MT at 120818) aborted on model load when the output projection tripped the assert. I replaced the assert with padding: block_num_y now rounds up to a whole number of subgroup-sized workgroups. The kernel already has the row bounds check (`if (row >= nrows) return;`) so the extra padded threads early-exit cleanly. Row values are uniform across a subgroup so the collective reduce stays safe. For aligned vocab sizes the padded block_num_y equals the old value, so the kernel launch is identical and there is no regression. Thanks to @arthw for flagging the relationship to #21527. Fixes #22020. AI assisted coding, tested on Intel B70 hardware. * sycl: use WARP_SIZE for num_subgroups in reorder MMVQ launches Replaces the hardcoded 16 with WARP_SIZE in the four reorder_mul_mat_vec launch helpers (Q4_0, Q8_0, Q4_K, Q6_K). Compile-time no-op on the Intel target where WARP_SIZE is 16, but makes the relationship to subgroup size explicit. Per review by @NeoZhangJianyu on #22035. Assisted by Claude. --- ggml/src/ggml-sycl/mmvq.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index af22b98dddb..3a4577ecbbc 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -537,9 +537,9 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -682,9 +682,9 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK8_0 == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -798,9 +798,9 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -842,9 +842,9 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); From 5f21fdcbb9400721732c161e4b04174e5d8625db Mon Sep 17 00:00:00 2001 From: neha-ha <137219201+neha-ha@users.noreply.github.com> Date: Mon, 20 Apr 2026 07:37:17 -0700 Subject: [PATCH 161/249] ggml-webgpu: updated matrix-vector multiplication (llama/21738) * merged properly, but slow q3_k and q5_k with u32 indexing * Start on new mat-vec * New format float paths working * Working q4_0 * Work on remaining legacy q-types * port k-quants to new matvec * remove old shader * Remove old constants, format * remove accidental file --------- Co-authored-by: Neha Abbas Co-authored-by: Reese Levine --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 34 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 28 +- .../wgsl-shaders/common_decls.tmpl | 7 + .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 1102 +++++++++++------ 4 files changed, 788 insertions(+), 383 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 7d9a4403fab..9d88f98050e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -44,18 +44,9 @@ // Matrix-vector multiplication parameters #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 -// Must be multiple of 4 to work with vectorized paths, and must divide -// mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256 - -#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256 - -// Requires 32 threads per output (wg_size/outputs_per_wg == 32) -#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8 -// Requires at least two (and multiple of 2) k-quant blocks per tile -#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512 +#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4 +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4 +#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4 // default size for legacy matrix multiplication #define WEBGPU_MUL_MAT_WG_SIZE 256 @@ -78,6 +69,7 @@ struct ggml_webgpu_shader_lib_context { bool inplace = false; bool overlap = false; bool src_overlap = false; + bool supports_subgroups = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; @@ -575,7 +567,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { struct ggml_webgpu_mul_mat_vec_shader_decisions { uint32_t wg_size; - uint32_t tile_k; uint32_t outputs_per_wg; uint32_t vec_size; }; @@ -1326,7 +1317,7 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0; @@ -1337,7 +1328,8 @@ class ggml_webgpu_shader_lib { } std::vector defines; - std::string variant = "mul_mat_vec"; + std::string variant = "mul_mat_vec"; + const char * shader_src = wgsl_mul_mat_vec; // src0 type (matrix row) switch (context.src0->type) { @@ -1386,25 +1378,25 @@ class ggml_webgpu_shader_lib { defines.push_back(key.vectorized ? "VEC" : "SCALAR"); uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; - uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K; uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; if (key.src0_type >= GGML_TYPE_Q2_K) { - tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K; outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; } else if (key.src0_type >= GGML_TYPE_Q4_0) { - tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K; outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + if (key.vectorized) { + variant += "_vectorized"; + } - auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines); + auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); decisions->wg_size = wg_size; - decisions->tile_k = tile_k; decisions->outputs_per_wg = outputs_per_wg; decisions->vec_size = key.vectorized ? 4 : 1; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e7bda817a28..aa20a745e0a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -181,6 +181,7 @@ struct webgpu_dispatch_desc { struct webgpu_capabilities { wgpu::Limits limits; + bool supports_subgroups = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; @@ -1164,14 +1165,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_Q6_K: - use_fast = true; - break; - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: - // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat - use_fast = !is_vec; + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q2_K: + use_fast = true; break; default: break; @@ -1182,10 +1180,12 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, } ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = src0; - shader_lib_ctx.src1 = src1; - shader_lib_ctx.dst = dst; + + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; @@ -1287,7 +1287,8 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; // Get or create pipeline - webgpu_pipeline gather_pipeline, main_pipeline; + webgpu_pipeline gather_pipeline; + webgpu_pipeline main_pipeline; std::vector dispatches; @@ -3040,6 +3041,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->adapter.GetFeatures(&features); // we require f16 support GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + ctx->webgpu_global_ctx->capabilities.supports_subgroups = + ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); #ifndef __EMSCRIPTEN__ // Accept f16 subgroup matrix configurations (square or non-square). @@ -3072,11 +3075,14 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { #ifndef __EMSCRIPTEN__ required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { - required_features.push_back(wgpu::FeatureName::Subgroups); required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); } #endif + if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) { + required_features.push_back(wgpu::FeatureName::Subgroups); + } + #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 62fe72ee3b1..14c045b0ba6 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -45,6 +45,13 @@ fn load_u16_at_src0(byte_offset: u32) -> u32 { return (word >> shift) & 0xFFFFu; } +// Always reads the 4-byte-aligned word containing byte_offset. +// Caller extracts the 16-bit half it needs via & 0xFFFFu or >> 16u. +// this is used in k-quants for better performance +fn load_u32_at_src0_aligned(byte_offset: u32) -> u32 { + return src0[(byte_offset & ~3u) / 4u]; +} + fn load_u32_at_src0(byte_offset: u32) -> u32 { let word_idx = byte_offset / 4u; let shift = (byte_offset & 0x3u) * 8u; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 9f7b3e32eca..97c9f6d7a09 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -1,465 +1,865 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif enable f16; #define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 -#ifdef VEC +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif -#define VEC_SIZE 4 -#define DST_TYPE vec4 +#ifdef VEC +#define VEC_SIZE 4u #define SRC0_TYPE vec4 #define SRC1_TYPE vec4 fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { return f32(dot(SRC1_TYPE(src0_val), src1_val)); } - -fn store_val(group_base: u32) -> vec4 { - return vec4(partial_sums[group_base], - partial_sums[group_base + THREADS_PER_OUTPUT], - partial_sums[group_base + THREADS_PER_OUTPUT * 2], - partial_sums[group_base + THREADS_PER_OUTPUT * 3]); -} #endif #ifdef SCALAR - -#define VEC_SIZE 1 -#define DST_TYPE f32 +#define VEC_SIZE 1u #define SRC0_TYPE SRC0_INNER_TYPE #define SRC1_TYPE SRC1_INNER_TYPE fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { return f32(src0_val) * f32(src1_val); } +#endif + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; -fn store_val(group_base: u32) -> f32 { - return partial_sums[group_base]; +@group(0) @binding(0) var src0: array; +@group(0) @binding(1) var src1: array; +@group(0) @binding(2) var dst: array; + +@group(0) @binding(3) var params: MulMatParams; + +// Flattened as [row][thread] to keep each row's reduction contiguous in memory. +var partial_sums: array; + +fn partial_index(row: u32, thread: u32) -> u32 { + return row * WG_SIZE + thread; } + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32 #endif +) { + let thread_id = local_id.x; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let batch_idx = wg_linear / output_groups; + if (batch_idx >= total_batches) { + return; + } + + let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG; + + let dst2_stride = params.m * params.n; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; + + var acc: array; #ifdef MUL_ACC_FLOAT -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) { - let a = src0[(idx_base + k_outer + i) / VEC_SIZE]; - let b = shared_vector[i / VEC_SIZE]; - local_sum += inner_dot(a, b); + let k_vec = params.k / VEC_SIZE; + let src1_idx_base_vec = src1_idx_base / VEC_SIZE; + + // Each thread walks K, loads from the vector, and updates + // a small block of output rows held in registers. + for (var k = thread_id; k < k_vec; k += WG_SIZE) { + let x = src1[src1_idx_base_vec + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; + acc[row] += inner_dot(src0[src0_idx], x); + } + } } - return local_sum; -} #endif #ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % 4; + for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 18u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0) * d; - local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } } } - return local_sum; -} #endif #ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 20 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 20u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = f32(load_f16_at_src0(block_byte_base + 2u)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32((q_byte >> 4) & 0xF) * d + m; - let q_lo = f32(q_byte & 0xF) * d + m; - local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } } } - return local_sum; -} #endif #ifdef MUL_ACC_Q5_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 22 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 22u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - let qh_packed = load_u32_at_src0(block_byte_base + 2u); - - for (var j = 0u; j < 2; j++) { - let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at_src0(q_byte_offset); - - let j_adjusted = j + (block_offset / 2u); - - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; - let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; - let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - - local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); + let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } - } } - return local_sum; -} #endif - #ifdef MUL_ACC_Q5_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 24 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 24u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = load_f16_at_src0(block_byte_base + 2u); - let qh_packed = load_u32_at_src0(block_byte_base + 4u); - - for (var j = 0u; j < 2; j++) { - let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at_src0(q_byte_offset); - - let j_adjusted = j + (block_offset / 2u); - - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; - let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m); - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; - let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m); - - local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + let qh_packed = load_u32_at_src0(block_byte_base + 4u); + let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } - } } - return local_sum; -} #endif - #ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 34 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 34u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 2u; -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d; - local_sum += q_val * shared_vector[shmem_idx + j * 2 + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; } } } - return local_sum; -} #endif - #ifdef MUL_ACC_Q8_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 36 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 36u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 2u; -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = load_f16_at_src0(block_byte_base + 2u); - - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d + f32(m); - local_sum += q_val * shared_vector[shmem_idx + j * 2 + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; } } } - return local_sum; -} #endif -#ifdef MUL_ACC_Q6_K - -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 210u; - -fn byte_of(v: u32, b: u32) -> u32 { - return (v >> (b * 8u)) & 0xFFu; -} +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 84 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let iq = lane / 4u; + let ir = lane % 4u; + let is = ir / 2u; + + let y_offset = 128u * iq + 8u * ir + 4u * phase; + let sc0_byte = 8u * iq + is; + let sc2_byte = 8u * iq + is + 2u; + let sc4_byte = 8u * iq + is + 4u; + let sc6_byte = 8u * iq + is + 6u; + let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 64u + i]); + x_block[i + 12u] = f32(src1[x_base + 96u + i]); + } -fn sbyte_of(v: u32, b: u32) -> i32 { - let raw = i32((v >> (b * 8u)) & 0xFFu); - return select(raw, raw - 256, raw >= 128); -} + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let dall = f32(load_f16_at_src0(block_byte_base + 80u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0); + + let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u); + let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u); + let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u); + let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u); + + let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte); + let qs0 = q_u32 & 0xFFFFu; + let qs1 = q_u32 >> 16u; + + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + + sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; + sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; + sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; + sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + + acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + + acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } + } + } +#endif -fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - let tid = tig / 2u; - let ix = tig % 2u; - let ip = tid / 8u; - let il = tid % 8u; - let l0 = 4u * il; - let is = 8u * ip + l0 / 16u; - let y_offset = 128u * ip + l0; - let q_offset_l = 64u * ip + l0; - let q_offset_h = 32u * ip + l0; +#ifdef MUL_ACC_Q3_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 - let nb = tile_size / BLOCK_SIZE; - let k_block_start = k_outer / BLOCK_SIZE; + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - // Aligned scale byte position (is can be odd) - let sc_base_byte = 192u + (is & ~3u); - let sc_byte_pos = is & 3u; + let lane = tid / 2u; + let phase = tid % 2u; + let ip = lane / 4u; + let il = 2u * ((lane % 4u) / 2u); + let ir = lane % 2u; + let l0 = 8u * ir; - var local_sum = 0.0; + let q_byte = 32u + 32u * ip + l0 + 16u * phase; + let h_byte = l0 + 16u * phase; + let y_offset = 128u * ip + 32u * il + l0 + 16u * phase; - for (var i = ix; i < nb; i += 2u) { - let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES; + let s_shift1 = 4u * ip; + let s_shift2 = s_shift1 + il; - let d = f32(load_f16_at_src0(bbase + 208u)); + let v1 = select(64.0, 4.0, il == 0u); + let v2 = 4.0 * v1; + let shift = 2u * il; - let ql1_u32 = load_u32_at_src0(bbase + q_offset_l); - let ql2_u32 = load_u32_at_src0(bbase + q_offset_l + 32u); - let qh_u32 = load_u32_at_src0(bbase + 128u + q_offset_h); - let sc_u32_0 = load_u32_at_src0(bbase + sc_base_byte); - let sc_u32_1 = load_u32_at_src0(bbase + sc_base_byte + 4u); + var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32; + if (il == 0u) { + qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u; + } else { + qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u; + } - let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); - let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); - let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); - let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); + let mm_idx = 2u * ip + il / 2u; + var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32; + switch (mm_idx) { + case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; } + case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; } + case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; } + default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; } + } - var sums = vec4(0.0, 0.0, 0.0, 0.0); + let num_blocks = params.k / BLOCK_SIZE; - for (var l = 0u; l < 4u; l++) { - let y_base = i * BLOCK_SIZE + y_offset + l; - let yl0 = f32(shared_vector[y_base]); - let yl1 = f32(shared_vector[y_base + 32u]); - let yl2 = f32(shared_vector[y_base + 64u]); - let yl3 = f32(shared_vector[y_base + 96u]); - - let q1b = byte_of(ql1_u32, l); - let q2b = byte_of(ql2_u32, l); - let qhb = byte_of(qh_u32, l); - - let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); - let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); - let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u) )) - 32); - let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); - - sums[0] += yl0 * dq0; - sums[1] += yl1 * dq1; - sums[2] += yl2 * dq2; - sums[3] += yl3 * dq3; + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 8u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 8u] = f32(src1[x_base + 32u + i]); } - local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + - sums[2] * f32(sc4) + sums[3] * f32(sc6)); + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 108u)); + let a_base = 96u; + let a_il0 = load_u16_at_src0(block_byte_base + a_base + il * 2u); + let a_il1 = load_u16_at_src0(block_byte_base + a_base + (il + 1u) * 2u); + let a_4 = load_u16_at_src0(block_byte_base + a_base + 8u); + let a_5 = load_u16_at_src0(block_byte_base + a_base + 10u); + + var scales32 = a_4 | (a_5 << 16u); + let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; + scales32 = a_il0 | (a_il1 << 16u); + scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; + + let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32); + let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32); + + let q_u32_0 = load_u32_at_src0(block_byte_base + q_byte + 0u); + let q_u32_1 = load_u32_at_src0(block_byte_base + q_byte + 4u); + let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); + let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); + + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + + s1 += x_block[l + 0u] * f32(qs & qm0); + s2 += x_block[l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[l + 1u], (hv & hm1) == 0u); + s4 += x_block[l + 8u] * f32(qs & qm2); + s5 += x_block[l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); + } + } } - - return local_sum; -} #endif -struct MulMatParams { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - m: u32, - n: u32, - k: u32, - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included -@group(0) @binding(0) var src0: array; // M rows, K columns -@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) - -@group(0) @binding(3) var params: MulMatParams; - -const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG; +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 144 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 32u * im + l0; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } -// Shared memory for collaborative loading and reduction -var shared_vector: array; // Cache vector tile -var partial_sums: array; // For reduction + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let scale0 = f32(sc16_0 & 0xFFu); + let scale1 = f32((sc16_0 >> 8u) & 0xFFu); + let min0 = f32(sc16_1 & 0xFFu); + let min1 = f32((sc16_1 >> 8u) & 0xFFu); + let scale2 = f32(sc16_2 & 0xFFu); + let scale3 = f32((sc16_2 >> 8u) & 0xFFu); + let min2 = f32(sc16_3 & 0xFFu); + let min3 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); + + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[i] * f32(q1b & 0x0Fu); + dot[1] += x_block[i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[i]; + sumx[1] += x_block[i + 4u]; + sumx[2] += x_block[i + 8u]; + sumx[3] += x_block[i + 12u]; + } + + acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } + } + } +#endif -@compute @workgroup_size(WG_SIZE) -fn main( - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) wg_id: vec3, - @builtin(num_workgroups) num_wg: vec3) { - let thread_id = local_id.x; +#ifdef MUL_ACC_Q5_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 176 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 48u + 32u * im + l0; + let qh_offset = 16u + 8u * ir + 4u * in; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let hm1 = 1u << (2u * im); + let hm2 = hm1 << 1u; + let hm3 = hm1 << 4u; + let hm4 = hm2 << 4u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } - // Handle batch dimensions - let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; - let batch_idx = wg_linear / output_groups; - if (batch_idx >= total_batches) { - return; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let f0 = f32(sc16_0 & 0xFFu); + let f1 = f32((sc16_0 >> 8u) & 0xFFu); + let m0 = f32(sc16_1 & 0xFFu); + let m1 = f32((sc16_1 >> 8u) & 0xFFu); + let f4 = f32(sc16_2 & 0xFFu); + let f5 = f32((sc16_2 >> 8u) & 0xFFu); + let m4 = f32(sc16_3 & 0xFFu); + let m5 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); + let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); + + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); + + let yl0 = x_block[i]; + let yl8 = x_block[i + 4u]; + let yh0 = x_block[i + 8u]; + let yh8 = x_block[i + 12u]; + + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; + + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); + } + } } +#endif - // Which of the outputs does this thread belong to? - let thread_group = thread_id / THREADS_PER_OUTPUT; - let thread_in_group = thread_id % THREADS_PER_OUTPUT; +#ifdef MUL_ACC_Q6_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 210 +#define THREADS_PER_BLOCK 16 - // Each workgroup computes OUTPUTS_PER_WG consecutive outputs - let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group; + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - let dst2_stride = params.m * params.n; - let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); - let src03_idx = dst3_idx / params.broadcast3; - let src13_idx = dst3_idx; - let src02_idx = dst2_idx / params.broadcast2; - let src12_idx = dst2_idx; + let ip = tid / 8u; + let il = tid % 8u; + let l0 = 4u * il; + let is = 8u * ip + l0 / 16u; - let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row; + let y_offset = 128u * ip + l0; + let q_offset_l = 64u * ip + l0; + let q_offset_h = 32u * ip + l0; - var local_sum = 0.0; + let num_blocks = params.k / BLOCK_SIZE; + let sc_base_byte = 192u + (is & ~3u); + let sc_byte_pos = is & 3u; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var l = 0u; l < 4u; l++) { + x_block[l] = f32(src1[x_base + l]); + x_block[l + 4u] = f32(src1[x_base + 32u + l]); + x_block[l + 8u] = f32(src1[x_base + 64u + l]); + x_block[l + 12u] = f32(src1[x_base + 96u + l]); + } - // Each thread processes multiple K elements and accumulates - for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) { - let tile_size = min(TILE_K, params.k - k_tile); + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 208u)); + let ql1_u32 = load_u32_at_src0(block_byte_base + q_offset_l); + let ql2_u32 = load_u32_at_src0(block_byte_base + q_offset_l + 32u); + let qh_u32 = load_u32_at_src0(block_byte_base + 128u + q_offset_h); + let sc_u32_0 = load_u32_at_src0(block_byte_base + sc_base_byte); + let sc_u32_1 = load_u32_at_src0(block_byte_base + sc_base_byte + 4u); + + let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); + let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); + let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); + let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); + + var sums = vec4(0.0, 0.0, 0.0, 0.0); + + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); + + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + + sums[0] += x_block[l] * dq0; + sums[1] += x_block[l + 4u] * dq1; + sums[2] += x_block[l + 8u] * dq2; + sums[3] += x_block[l + 12u] * dq3; + } + + acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); + } + } + } +#endif - // Cooperatively load vector tile into shared memory (all threads) - for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) { - shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE]; +#ifdef USE_SUBGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; } + } - workgroupBarrier(); + workgroupBarrier(); - if (output_row < params.m) { - local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile); + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + row] = row_total; + } + } +#endif - workgroupBarrier(); +#ifdef USE_WORKGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[row]; } - // Store partial sums and reduce within each partition - partial_sums[thread_id] = local_sum; workgroupBarrier(); - let group_base = thread_group * THREADS_PER_OUTPUT; - let thread_base = group_base + thread_in_group; - var offset: u32 = THREADS_PER_OUTPUT / 2; - while (offset > 0) { - if (thread_in_group < offset) { - partial_sums[thread_base] += partial_sums[thread_base + offset]; + + var stride = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } } - offset = offset / 2; + workgroupBarrier(); + stride = stride / 2; } - // Store back to global memory - if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) { - dst[dst_idx / VEC_SIZE] = store_val(group_base); + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } } +#endif } From 2b9fb0be770f188e9d6b506403e3f3606f8a66dc Mon Sep 17 00:00:00 2001 From: pl752 Date: Mon, 20 Apr 2026 21:02:54 +0500 Subject: [PATCH 162/249] ggml-cpu: Optimized x86 and generic cpu q1_0 dot (follow up) (llama/21636) * Implemented optimized q1_0 dot for x86 and generic * Removed redundant helper definition * Removed two redundant instructions from AVX q1_0 dot * Fixed inconsistency with fp16 conversion for generic q1_0 dot and deduplicated generic fallback * Style cleanup around AVX q1_0 dot * Replaced explicitly unrolled blocks with inner for loop for q1_0 * Replaced scalar ARM q1_0 impl with new generic one --- ggml/src/ggml-cpu/arch-fallback.h | 1 - ggml/src/ggml-cpu/arch/arm/quants.c | 30 +----- ggml/src/ggml-cpu/arch/x86/quants.c | 158 ++++++++++++++++++++++++++++ ggml/src/ggml-cpu/quants.c | 26 +++-- 4 files changed, 179 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index c589a213e9d..595ded09f03 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -83,7 +83,6 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 -#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 64d811fafe7..fe621332970 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -151,8 +151,6 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi const block_q1_0 * GGML_RESTRICT x = vx; const block_q8_0 * GGML_RESTRICT y = vy; - float sumf = 0.0f; - #if defined(__ARM_NEON) float32x4_t sumv = vdupq_n_f32(0.0f); @@ -212,31 +210,13 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi } } - sumf = vaddvq_f32(sumv); + *s = vaddvq_f32(sumv); #else - // Scalar fallback - for (int i = 0; i < nb; i++) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); - - // Process 4 Q8_0 blocks - for (int k = 0; k < 4; k++) { - const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); - - int sumi = 0; - for (int j = 0; j < QK8_0; j++) { - const int bit_index = k * QK8_0 + j; - const int byte_index = bit_index / 8; - const int bit_offset = bit_index % 8; - - const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; - sumi += xi * y[i*4 + k].qs[j]; - } - sumf += d0 * d1 * sumi; - } - } + UNUSED(nb); + UNUSED(x); + UNUSED(y); + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif - - *s = sumf; } diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 74d699f633d..0a3e071e57c 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -274,6 +274,18 @@ static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const } #endif #elif defined(__SSSE3__) +static inline __m128i bytes_from_bits_16(const uint8_t * x) { + uint16_t x16; + memcpy(&x16, x, sizeof(uint16_t)); + + const __m128i shuf_mask = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + __m128i bytes = _mm_shuffle_epi8(_mm_set1_epi16((short) x16), shuf_mask); + const __m128i bit_mask = _mm_set_epi64x(0x7fbfdfeff7fbfdfe, 0x7fbfdfeff7fbfdfe); + bytes = _mm_or_si128(bytes, bit_mask); + + return _mm_cmpeq_epi8(bytes, _mm_set1_epi64x(-1)); +} + // horizontally add 4x4 floats static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { __m128 res_0 =_mm_hadd_ps(a, b); @@ -540,6 +552,152 @@ static inline __m128i get_scale_shuffle(int i) { } #endif +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__AVX2__) + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i byte_shuf = _mm256_setr_epi8( + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3); + const __m256i bit_masks = _mm256_setr_epi8( + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128, + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128); + const __m256i zero = _mm256_setzero_si256(); + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + const uint32_t * GGML_RESTRICT qs32 = (const uint32_t *) x[ib].qs; + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + + __m256 acc_block; + { + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[0].qs); + const __m256i sm = _mm256_cmpeq_epi8( + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[0]), byte_shuf), bit_masks), zero); + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); + acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), _mm256_cvtepi32_ps(s32)); + } + for (int K = 1; K < 4; ++K) { + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); + const __m256i sm = _mm256_cmpeq_epi8( + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[K]), byte_shuf), bit_masks), zero); + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); + acc_block = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), _mm256_cvtepi32_ps(s32), acc_block); + } + acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block, acc); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + __m256 acc_block; + { + const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[0]); + const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); + const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[0]); + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[16]); + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); + const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); + const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); + const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); + const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); + const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); + acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), q); + } + for(int K = 1; K < 4; ++K) { + const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[(K) * 4]); + const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); + const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); + const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); + const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); + const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); + const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); + const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); + acc_block = _mm256_add_ps(acc_block, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)), q)); + } +#undef Q1_AVX_BLOCK + + acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block)); + } + + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const __m128 d0 = _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d)); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + +#define Q1_SSSE3_BLOCK(QS_OFF, Y_IDX, ACC) \ + { \ + const __m128i bit_mask_0 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 0]); \ + const __m128i bit_mask_1 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 2]); \ + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[0]); \ + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[16]); \ + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \ + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \ + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \ + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \ + const __m128i sum_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0), ones_16); \ + const __m128i sum_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1), ones_16); \ + const __m128 q = _mm_cvtepi32_ps(_mm_add_epi32(sum_0, sum_1)); \ + (ACC) = _mm_add_ps((ACC), _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(Y_IDX)].d))), q)); \ + } + Q1_SSSE3_BLOCK(0, 0, acc_0) + Q1_SSSE3_BLOCK(4, 1, acc_1) + Q1_SSSE3_BLOCK(8, 2, acc_2) + Q1_SSSE3_BLOCK(12, 3, acc_3) +#undef Q1_SSSE3_BLOCK + } + + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index f66127c2290..e5f9a4083f9 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -137,22 +137,28 @@ void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c float sumf = 0.0; for (int i = 0; i < nb; i++) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); + const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); float sumi = 0.0f; for (int k = 0; k < 4; k++) { - const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); - + const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k]; + const float d1 = GGML_CPU_FP16_TO_FP32(yb->d); int sumi_block = 0; - for (int j = 0; j < QK8_0; j++) { - const int bit_index = k * QK8_0 + j; - const int byte_index = bit_index / 8; - const int bit_offset = bit_index % 8; - - const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; - sumi_block += xi * y[i*4 + k].qs[j]; + const uint8_t * GGML_RESTRICT bits = &x[i].qs[k * 4]; + const int8_t * GGML_RESTRICT qy = yb->qs; + + for (int b = 0; b < 4; ++b, qy += 8) { + const unsigned mask = bits[b]; + sumi_block += ((mask & 0x01) ? qy[0] : -qy[0]) + + ((mask & 0x02) ? qy[1] : -qy[1]) + + ((mask & 0x04) ? qy[2] : -qy[2]) + + ((mask & 0x08) ? qy[3] : -qy[3]) + + ((mask & 0x10) ? qy[4] : -qy[4]) + + ((mask & 0x20) ? qy[5] : -qy[5]) + + ((mask & 0x40) ? qy[6] : -qy[6]) + + ((mask & 0x80) ? qy[7] : -qy[7]); } sumi += d1 * sumi_block; From 6429023e5f48c37b03e4903bf2bab8ef875b244f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 20 Apr 2026 18:09:39 +0200 Subject: [PATCH 163/249] TP: fix 0-sized tensor slices, AllReduce fallback (llama/21808) * TP: fix 0-sized tensor slices, AllReduce fallback * fix layer structure <-> GPU count aliasing * add missing std::fill * fix CUDA device set, max ggml ctx size --- ggml/src/ggml-backend-meta.cpp | 218 +++++++++++++++++++++----------- ggml/src/ggml-cuda/ggml-cuda.cu | 13 +- 2 files changed, 154 insertions(+), 77 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 39651adc1c1..4bf90c6a98b 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1133,7 +1133,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) { t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j); if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) { - GGML_ASSERT(ne[split_dim] != 0 && tensor->ne[split_dim] != 0); + GGML_ASSERT(tensor->ne[split_dim] != 0); const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis; GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS); @@ -1170,6 +1170,28 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer simple_tensors.push_back(t_ij); } + + // If one of the sources has a zero-sized slice, disable the computation: + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || !ggml_backend_buffer_is_meta(tensor->src[i]->buffer)) { + continue; + } + + const ggml_backend_meta_split_state split_state_src = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); + if (split_state_src.axis < 0 || split_state_src.axis >= GGML_MAX_DIMS) { + continue; + } + for (size_t j = 0; j < n_simple_bufs; j++) { + int64_t ne_sum = 0; + for (size_t s = 0; s < split_state_src.n_segments; s++) { + ne_sum += split_state_src.ne[s*n_simple_bufs + j]; + } + if (ne_sum == 0) { + simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; + } + } + } + buf_ctx->simple_tensors[tensor] = simple_tensors; return GGML_STATUS_SUCCESS; @@ -1442,17 +1464,20 @@ struct ggml_backend_meta_context { struct backend_config { ggml_backend_t backend; - std::vector cgraphs; - std::vector nodes; - ggml_backend_buffer_ptr buf; + std::vector cgraphs; + std::vector nodes; + std::vector bufs; - backend_config(ggml_backend_t backend) : backend(backend) {} + backend_config(ggml_backend_t backend, const size_t n_reduce_steps) : backend(backend) { + bufs.resize(n_reduce_steps); + } }; std::string name; std::vector backend_configs; ggml_context_ptr ctx; std::vector cgraphs_aux; std::vector nodes_aux; + size_t n_reduce_steps; int max_nnodes = 0; size_t max_tmp_size = 0; size_t max_subgraphs = 0; @@ -1464,6 +1489,7 @@ struct ggml_backend_meta_context { ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) { const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev); + n_reduce_steps = std::ceil(std::log2(n_devs)); name = "Meta("; std::vector simple_backends; backend_configs.reserve(n_devs); @@ -1475,7 +1501,7 @@ struct ggml_backend_meta_context { } name += ggml_backend_dev_name(simple_dev); simple_backends.push_back(ggml_backend_dev_init(simple_dev, params)); - backend_configs.emplace_back(simple_backends.back()); + backend_configs.emplace_back(simple_backends.back(), n_reduce_steps); } name += ")"; @@ -1505,10 +1531,6 @@ struct ggml_backend_meta_context { ggml_backend_free(bc.backend); } } - - size_t n_reduce_steps() const { - return std::ceil(std::log2(backend_configs.size())); - } }; static const char * ggml_backend_meta_get_name(ggml_backend_t backend) { @@ -1754,16 +1776,17 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, if (max_tmp_size > backend_ctx->max_tmp_size) { for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; - bcj.buf.reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); + for (size_t i = 0; i < backend_ctx->n_reduce_steps; i++) { + bcj.bufs[i].reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); + } } backend_ctx->max_tmp_size = max_tmp_size; } if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); - const size_t n_reduce_steps = backend_ctx->n_reduce_steps(); - const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step - const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step + const size_t n_nodes_per_device = 3 * backend_ctx->n_reduce_steps; // tmp + ADD (+zeroing) graph per step and device + const size_t n_cgraphs_per_device = 2 * backend_ctx->n_reduce_steps; // ADD ( + zeroing) graph per step and device const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); @@ -1812,11 +1835,6 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, size_t iga = 0; // i graph aux size_t ina = 0; // i node aux - // FIXME usage_counts - auto get_cgraph_aux = [&]() -> ggml_cgraph * { - ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++]; - return ret; - }; auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * { ggml_tensor * ret = backend_ctx->nodes_aux[ina++]; memset(ret, 0, sizeof(ggml_tensor)); @@ -1828,75 +1846,110 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } return ret; }; + auto set_tmp_data = [&](ggml_tensor * tensor, const size_t j, const size_t i_buf) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_backend_buffer_ptr & buf_ptr = bcj.bufs[i_buf]; + if (!buf_ptr || ggml_backend_buffer_get_size(buf_ptr.get()) < backend_ctx->max_tmp_size) { + buf_ptr.reset(ggml_backend_alloc_buffer(bcj.backend, backend_ctx->max_tmp_size)); + } + tensor->buffer = buf_ptr.get(); + tensor->data = ggml_backend_buffer_get_base(buf_ptr.get()); + }; + // FIXME usage_counts + auto get_cgraph_aux = [&]() -> ggml_cgraph * { + ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++]; + return ret; + }; // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable: auto allreduce_fallback = [&](size_t i) -> ggml_status { std::vector step_cgraphs(n_backends, nullptr); - for (size_t offset_j = 1; offset_j < n_backends; offset_j *= 2) { + // Zero out nodes that were disabled due to having a zero-sized slice: + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_tensor * node = bcj.cgraphs[i].cgraph_main->nodes[bcj.cgraphs[i].cgraph_main->n_nodes - 1]; + if (node->flags & GGML_TENSOR_FLAG_COMPUTE) { + continue; + } + ggml_tensor * node_zero = get_node_aux(node); + node_zero->op = GGML_OP_SCALE; // FIXME 0.0f * NaN == NaN + node_zero->src[0] = node; + ggml_set_op_params_f32(node_zero, 0, 0.0f); + node_zero->data = node->data; + node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE; + + step_cgraphs[j] = get_cgraph_aux(); + step_cgraphs[j]->nodes[0] = node_zero; + step_cgraphs[j]->n_nodes = 1; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); + + auto push_data = [&](const size_t j_src, const size_t j_dst, const size_t i_buf) { + assert(step_cgraphs[j_dst] == nullptr); + auto & bcj_src = backend_ctx->backend_configs[j_src]; + auto & bcj_dst = backend_ctx->backend_configs[j_dst]; + + ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1]; + GGML_ASSERT(ggml_is_contiguous(node_src)); + GGML_ASSERT(ggml_is_contiguous(node_dst)); + + ggml_tensor * node_tmp = get_node_aux(node_dst); + set_tmp_data(node_tmp, j_dst, i_buf); + + ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_tmp); + + ggml_tensor * node_red = get_node_aux(node_dst); + node_red->view_src = node_dst->view_src == nullptr ? node_dst : node_dst->view_src; + node_red->view_offs = node_dst->view_offs; + node_red->op = GGML_OP_ADD; + node_red->src[0] = node_dst; + node_red->src[1] = node_tmp; + node_red->flags |= GGML_TENSOR_FLAG_COMPUTE; + ggml_backend_view_init(node_red); + + ggml_cgraph * cgraph_aux = get_cgraph_aux(); + cgraph_aux->nodes[0] = node_red; + cgraph_aux->n_nodes = 1; + step_cgraphs[j_dst] = cgraph_aux; + }; + + size_t offset_j = n_backends/2; + while ((offset_j & (offset_j - 1)) != 0) { + offset_j--; + } + const size_t offset_j_max = offset_j; + size_t i_buf = 0; + + // If n_backends is not a power of 2, fold in the excess prior to butterfly reduction: + for (size_t j_src = 2*offset_j_max; j_src < n_backends; j_src++) { + const size_t j_dst = j_src - 2*offset_j_max; + push_data(j_src, j_dst, i_buf); + const ggml_status status = ggml_backend_graph_compute_async(backend_ctx->backend_configs[j_dst].backend, step_cgraphs[j_dst]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + i_buf = 1; + } + + // Butterfly reduction: + for (; offset_j >= 1; offset_j /= 2) { std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); - for (size_t j = 0; j < n_backends; j++) { + for (size_t j = 0; j < 2*offset_j_max; j++) { const size_t j_other = j ^ offset_j; - if (j_other > j) { + if (j_other >= n_backends) { continue; } - - auto & bcj1 = backend_ctx->backend_configs[j]; - auto & bcj2 = backend_ctx->backend_configs[j_other]; - - ggml_tensor * node1 = bcj1.cgraphs[i].cgraph_main->nodes[bcj1.cgraphs[i].cgraph_main->n_nodes - 1]; - ggml_tensor * node2 = bcj2.cgraphs[i].cgraph_main->nodes[bcj2.cgraphs[i].cgraph_main->n_nodes - 1]; - GGML_ASSERT(ggml_is_contiguous(node1)); - GGML_ASSERT(ggml_is_contiguous(node2)); - - // Tmp tensors to receive P2P copies - ggml_tensor * node_tmp_1 = get_node_aux(node1); - node_tmp_1->buffer = bcj1.buf.get(); - node_tmp_1->data = ggml_backend_buffer_get_base(bcj1.buf.get()); - - ggml_tensor * node_tmp_2 = get_node_aux(node2); - node_tmp_2->buffer = bcj2.buf.get(); - node_tmp_2->data = ggml_backend_buffer_get_base(bcj2.buf.get()); - - // 2 P2P copies: exchange full buffers - ggml_backend_tensor_copy_async(bcj1.backend, bcj2.backend, node1, node_tmp_2); - ggml_backend_tensor_copy_async(bcj2.backend, bcj1.backend, node2, node_tmp_1); - - // Local ADD: node1 += tmp1 (in-place via view) - ggml_tensor * node_red_1 = get_node_aux(node1); - node_red_1->view_src = node1->view_src == nullptr ? node1 : node1->view_src; - node_red_1->view_offs = node1->view_offs; - node_red_1->op = GGML_OP_ADD; - node_red_1->src[0] = node1; - node_red_1->src[1] = node_tmp_1; - node_red_1->flags |= GGML_TENSOR_FLAG_COMPUTE; - ggml_backend_view_init(node_red_1); - - // Local ADD: node2 += tmp2 (in-place via view) - ggml_tensor * node_red_2 = get_node_aux(node2); - node_red_2->view_src = node2->view_src == nullptr ? node2 : node2->view_src; - node_red_2->view_offs = node2->view_offs; - node_red_2->op = GGML_OP_ADD; - node_red_2->src[0] = node2; - node_red_2->src[1] = node_tmp_2; - node_red_2->flags |= GGML_TENSOR_FLAG_COMPUTE; - ggml_backend_view_init(node_red_2); - - // Build 1-node cgraphs for the ADD ops - ggml_cgraph * cgraph_aux_1 = get_cgraph_aux(); - cgraph_aux_1->nodes[0] = node_red_1; - cgraph_aux_1->n_nodes = 1; - step_cgraphs[j] = cgraph_aux_1; - - ggml_cgraph * cgraph_aux_2 = get_cgraph_aux(); - cgraph_aux_2->nodes[0] = node_red_2; - cgraph_aux_2->n_nodes = 1; - step_cgraphs[j_other] = cgraph_aux_2; + push_data(j, j_other, i_buf); } - // Execute local ADDs for this step - for (size_t j = 0; j < n_backends; j++) { + for (size_t j = 0; j < 2*offset_j_max; j++) { if (step_cgraphs[j] == nullptr) { continue; } @@ -1906,7 +1959,20 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, return status; } } + i_buf++; } + assert(i_buf == backend_ctx->n_reduce_steps); + + // If n_backends is not a power of 2, copy back the reduced tensors to the excess: + for (size_t j = 2*offset_j_max; j < n_backends; j++) { + auto & bcj_src = backend_ctx->backend_configs[j - 2*offset_j_max]; + auto & bcj_dst = backend_ctx->backend_configs[j]; + + ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_dst); + } + return GGML_STATUS_SUCCESS; }; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index de579d2ed50..ecd12b80dfe 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1203,6 +1203,13 @@ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct gg // For small tensors, simply reduce them as FP32. // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0. if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) { + for (size_t i = 0; i < n_backends; ++i) { + if ((tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + ggml_cuda_set_device(cuda_ctx->device); + CUDA_CHECK(cudaMemsetAsync(tensors[i]->data, 0, ggml_nbytes(tensors[i]), cuda_ctx->stream())); + } + } NCCL_CHECK(ncclGroupStart()); for (size_t i = 0; i < n_backends; ++i) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; @@ -1224,7 +1231,11 @@ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct gg tmp[i].alloc(ne); ggml_cuda_set_device(cuda_ctx->device); - to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream()); + if (tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) { + to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream()); + } else { + CUDA_CHECK(cudaMemsetAsync(tmp[i].get(), 0, ne * sizeof(nv_bfloat16), cuda_ctx->stream())); + } CUDA_CHECK(cudaGetLastError()); } From 239c5c86c30d36249e3479459914c4eb24958f19 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Mon, 20 Apr 2026 21:55:39 +0530 Subject: [PATCH 164/249] Tensor-parallel: Fix delayed AllReduce on Gemma-4 MoE (llama/22129) * Fix delayed AllReduce on Gemma-4 MoE Skip forward past nodes that don't consume the current one, and allow a chain of MULs. * Check for all sources before skipping nodes * Address review comments --- ggml/src/ggml-backend-meta.cpp | 42 ++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 4bf90c6a98b..6d22f3421b1 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1683,6 +1683,36 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, ggml_tensor * node = cgraph->nodes[id]; int32_t n_used = ggml_node_get_use_count(cgraph, id); + + // Skip MIRRORED nodes that don't consume node + auto skip_unrelated = [&]() { + while (id + 1 < cgraph->n_nodes) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + break; + } + bool safe = true; + for (int s = 0; s < GGML_MAX_SRC; s++) { + if (next->src[s] == nullptr) { + continue; + } + if (next->src[s] == node) { + safe = false; + break; + } + if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + safe = false; + break; + } + } + if (!safe) { + break; + } + id++; + } + }; + + skip_unrelated(); if (id + 1 >= cgraph->n_nodes) { return idr; } @@ -1697,10 +1727,12 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, n_used = ggml_node_get_use_count(cgraph, id); } } - if (id + 1 >= cgraph->n_nodes) { - return idr; - } - { + // Chain of MULs with MIRRORED src[1] + while (true) { + skip_unrelated(); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } ggml_tensor * next = cgraph->nodes[id+1]; if (next->op == GGML_OP_MUL && next->src[0] == node && ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { @@ -1708,6 +1740,8 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, id++; idr = id; n_used = ggml_node_get_use_count(cgraph, id); + } else { + break; } } From b13deaabaec1d52fc228195e56f96f1d83b7d2c0 Mon Sep 17 00:00:00 2001 From: leonardHONG <2695316095@qq.com> Date: Tue, 21 Apr 2026 05:30:38 +0800 Subject: [PATCH 165/249] ggml-cuda: flush legacy pool on OOM and retry (llama/22155) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-cuda: flush legacy pool on OOM and retry Signed-off-by: 梁厚宏 <2695316095@qq.com> * Address review comments: add explicit sync, update destructor, clean up MUSA macros Signed-off-by: 梁厚宏 <2695316095@qq.com> --------- Signed-off-by: 梁厚宏 <2695316095@qq.com> --- ggml/src/ggml-cuda/ggml-cuda.cu | 23 +++++++++++++++++++++-- ggml/src/ggml-cuda/vendors/hip.h | 1 + ggml/src/ggml-cuda/vendors/musa.h | 1 + 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ecd12b80dfe..185956317e0 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -368,15 +368,21 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } ~ggml_cuda_pool_leg() { + clear_pool(); + GGML_ASSERT(pool_size == 0); + } + + void clear_pool() { ggml_cuda_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cuda_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { CUDA_CHECK(cudaFree(b.ptr)); pool_size -= b.size; + b.ptr = nullptr; + b.size = 0; } } - GGML_ASSERT(pool_size == 0); } void * alloc(size_t size, size_t * actual_size) override { @@ -421,7 +427,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); ggml_cuda_set_device(device); - CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); + cudaError_t err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaErrorMemoryAllocation) { + (void)cudaGetLastError(); + const size_t cached_bytes = pool_size; + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: alloc of %.2f MiB failed, flushing %.2f MiB of cached buffers and retrying\n", + device, look_ahead_size/1024.0/1024.0, cached_bytes/1024.0/1024.0); + CUDA_CHECK(cudaDeviceSynchronize()); + clear_pool(); + err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaSuccess) { + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: retry succeeded\n", device); + } + } + CUDA_CHECK(err); *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 52c38908e06..78ca364d38f 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -58,6 +58,7 @@ #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t +#define cudaErrorMemoryAllocation hipErrorOutOfMemory #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags hipEventCreateWithFlags diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 1abb8acfd4b..8aa056e9174 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -42,6 +42,7 @@ #define cudaDeviceProp musaDeviceProp #define cudaDeviceSynchronize musaDeviceSynchronize #define cudaError_t musaError_t +#define cudaErrorMemoryAllocation musaErrorMemoryAllocation #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags musaEventCreateWithFlags From e7cffdbd0bc1ea97a605ab361907ef771b993bea Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 Apr 2026 11:02:56 +0300 Subject: [PATCH 166/249] ggml : bump version to 0.10.0 (ggml/1463) --- ggml/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index a0eb9204eab..2effd587b41 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 11) +set(GGML_VERSION_MINOR 10) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 85bbc822092a88d4feb0b2f8ddad0bb2de04488e Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 21 Apr 2026 11:01:56 +0200 Subject: [PATCH 167/249] vulkan: Support F16 OP_FILL (llama/22177) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 +++++++- .../src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 702a249d754..d4acee8b1df 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -792,6 +792,7 @@ struct vk_device_struct { vk_pipeline pipeline_arange_f32; vk_pipeline pipeline_fill_f32; + vk_pipeline pipeline_fill_f16; vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; @@ -4577,6 +4578,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_fill_f16, "fill_f16", fill_f16_len, fill_f16_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); #define CREATE_GLU(name) \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ @@ -9844,6 +9846,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_fill_f32; } + if (dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_fill_f16; + } return nullptr; default: return nullptr; @@ -15713,8 +15718,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32) || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16); case GGML_OP_ARANGE: - case GGML_OP_FILL: return op->type == GGML_TYPE_F32; + case GGML_OP_FILL: + return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; case GGML_OP_SCALE: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 54b9b327333..ff836615330 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -889,6 +889,7 @@ void process_shaders() { string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("fill_f16", "fill.comp", {{"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); From 150cef5a5f5c5272444eafbf090083476c8b1ccf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 Apr 2026 17:24:55 +0300 Subject: [PATCH 168/249] metal : workaround macOS GPU interactivity watchdog (llama/22216) --- ggml/src/ggml-metal/ggml-metal.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 4dbf8e6fea9..6a836e45908 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -918,6 +918,10 @@ ggml_backend_reg_t ggml_backend_metal_reg(void) { static std::vector devs; if (!initialized) { + // workaround macOS limitation (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) until proper fix becomes possible + // ref: https://github.com/ggml-org/llama.cpp/issues/20141#issuecomment-4272947703 + setenv("AGX_RELAX_CDM_CTXSTORE_TIMEOUT", "1", true); + static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init()); for (int i = 0; i < g_devices; ++i) { From 3a73f9cf0b3dc2a221000fd865545c783d3e978d Mon Sep 17 00:00:00 2001 From: Zijun Yu Date: Tue, 21 Apr 2026 23:58:34 +0800 Subject: [PATCH 169/249] openvino: driver setup, CI split, thread safety, and NPU optimizations (llama/21944) * Thread safety per request only * Fix ROPE yarn case * Fix sticky stateful config * Use i4/i8 directly for symmetric quant * Use weightless caching * Add WeightlessCacheAttribute to reduce NPU memory usage * Gelu tanh support (llama/125) * Imrope support (llama/126) * fix(openvino): explicit ov::Tensor frees in ggml_backend_openvino_free * add GPU,NPU support in OV Dockerfile * add build-openvino.yml ci * Fix sticky stateful config * add concurrency to ov-gpu ci runs. Move OV CI to build-openvino.yml * fix thread-safety of shared runtime context * rope type abstraction for frontend translations * fix editorconfig --------- Co-authored-by: Mustafa Cavus Co-authored-by: Dan Hoffman Co-authored-by: Ravi Panchumarthy --- ggml/src/ggml-openvino/ggml-decoder.cpp | 20 +- .../src/ggml-openvino/ggml-openvino-extra.cpp | 29 +- ggml/src/ggml-openvino/ggml-openvino.cpp | 42 +- ggml/src/ggml-openvino/ggml-quants.cpp | 456 ++++++++++-------- ggml/src/ggml-openvino/openvino/op/rope.cpp | 40 +- .../ggml-openvino/openvino/op/unary_gelu.cpp | 25 + ggml/src/ggml-openvino/openvino/op_table.cpp | 1 + ggml/src/ggml-openvino/openvino/op_table.h | 1 + .../openvino/pass/eliminate_zp.cpp | 123 ----- .../openvino/pass/eliminate_zp.h | 17 - .../rt_info/weightless_caching_attributes.hpp | 41 ++ .../openvino/translate_session.cpp | 30 +- ggml/src/ggml-openvino/openvino/utils.cpp | 103 ++-- ggml/src/ggml-openvino/openvino/utils.h | 1 + ggml/src/ggml-openvino/utils.cpp | 145 ++++-- ggml/src/ggml-openvino/utils.h | 26 +- 16 files changed, 646 insertions(+), 454 deletions(-) create mode 100644 ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp delete mode 100644 ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp delete mode 100644 ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h create mode 100644 ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 0938d2273e9..5095e799849 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -207,8 +206,22 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const { break; } case GGML_OP_ROPE: { + const int mode = node->op_params[2]; + switch (mode) { + case GGML_ROPE_TYPE_NEOX: { + op_case = 0x00010000; + break; + } + case GGML_ROPE_TYPE_IMROPE: { + op_case = 0x00020000; + break; + } + default: + op_case = 0x00000000; + break; + } if (node->src[0]->op == GGML_OP_VIEW) { - op_case = 2; + op_case = (op_case | 0x00000002); } break; } @@ -573,9 +586,6 @@ std::map GgmlOvDecoder::get_kv_param_res_names() const } std::map> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) { - static std::mutex weights_mutex; - std::lock_guard lock(weights_mutex); - std::map> model_weights; auto * nodes = cgraph->nodes; auto n_nodes = cgraph->n_nodes; diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index cc3cb4583cd..4140136aca2 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include ov::Core & ov_singleton_core() { @@ -42,11 +43,13 @@ void ggml_openvino_device_config::init() { {"NPUW_DQ", "YES" }, {"NPUW_DQ_FULL", "NO" }, }; - if (cache_dir) { + if (cache_dir && strlen(cache_dir) > 0) { compile_config["NPUW_CACHE_DIR"] = cache_dir; + compile_config.insert(ov::cache_mode(ov::CacheMode::OPTIMIZE_SIZE)); } - } else if (cache_dir) { - ov_singleton_core().set_property(ov::cache_dir(cache_dir)); + } else if (cache_dir && strlen(cache_dir) > 0) { + compile_config.insert(ov::cache_dir(cache_dir)); + compile_config.insert(ov::cache_mode(ov::CacheMode::OPTIMIZE_SIZE)); } // Initialize remote context with queue sharing for GPU @@ -259,10 +262,12 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); - // For symmetric quantization, we only need one zp value (not one per block) - // Zero points are stored in U4 or U8 format matching the weight type - size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; - layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements; + // For symmetric quantization, no zp needed (weights stored as signed) + if (layout.is_symmetric) { + layout.zp_size = 0; + } else { + layout.zp_size = layout.is_u4 ? ((n_blocks + 1) / 2) : n_blocks; + } layout.weights_offset = 0; layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; @@ -313,10 +318,12 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten // Scales: F16 per block int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes - // Zero points: U4 or U8 matching weight type - // For symmetric quantization, we only need one zp value (not one per block) - size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; - layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements; + // For symmetric quantization, no zp needed (weights stored as signed) + if (layout.is_symmetric) { + layout.zp_size = 0; + } else { + layout.zp_size = layout.is_u4 ? ((n_blocks + 1) / 2) : n_blocks; + } // Layout in buffer: [weights | scales | zp] with alignment layout.weights_offset = 0; diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 0c8d3508e87..4f3ebf2536b 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -145,13 +145,18 @@ static void * ggml_backend_openvino_buffer_get_base(ggml_backend_buffer_t buffer return ctx->data; } +static bool is_stateful_enabled() { + static const auto * stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION"); + return stateful && *stateful != '\0' && strcmp(stateful, "0") != 0; +} + static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; // Put kvcache on device memory for GPU (NPU memory is too small even for kvcache) if (strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && ggml_openvino_get_device_name() == "GPU" && - !getenv("GGML_OPENVINO_STATEFUL_EXECUTION")) { + !is_stateful_enabled()) { GGML_ASSERT(ctx->tensor_extras.empty()); auto device = ctx->device; auto size = ctx->size; @@ -600,6 +605,14 @@ bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft) { static void ggml_backend_openvino_free(ggml_backend_t backend) { ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context; + + if (ctx->runtime_context) { + auto r_ctx = std::static_pointer_cast(ctx->runtime_context); + if (--r_ctx->backend_count == 0) { + r_ctx->clear_caches(); + } + } + delete ctx; delete backend; } @@ -644,7 +657,12 @@ static ggml_guid_t ggml_backend_openvino_guid(void) { } static std::shared_ptr get_ov_runtime_context_ptr() { - static std::shared_ptr r_ctx = std::make_shared(); + static std::shared_ptr r_ctx = [] { + auto ctx = std::make_shared(); + ctx->device = ggml_openvino_get_device_name(); + ctx->stateful = is_stateful_enabled() && !ggml_openvino_is_npu(); + return ctx; + }(); return r_ctx; } @@ -669,8 +687,7 @@ GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device) { } std::shared_ptr r_ctx = std::static_pointer_cast(ctx->runtime_context); - r_ctx->device = ggml_openvino_get_device_name(); - r_ctx->stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION") && !ggml_openvino_is_npu(); + r_ctx->backend_count++; ggml_backend_t openvino_backend = new ggml_backend{ /* .guid = */ ggml_backend_openvino_guid(), @@ -883,7 +900,7 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { const int32_t * op_params = op->op_params; const int n_dims = op_params[1]; const int mode = op_params[2]; - if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { + if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_IMROPE) { // GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode); return true; } @@ -896,14 +913,6 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { // GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type)); return true; } - float freq_scale; - float ext_factor; - memcpy(&freq_scale, op_params + 6, sizeof(float)); - memcpy(&ext_factor, op_params + 7, sizeof(float)); - if (ext_factor != 0.0f) { - // GGML_LOG_WARN("OpenVINO backend does not support ROPE with ext_factor %f != 0.0f\n", ext_factor); - return true; - } if (op->src[0]->op == GGML_OP_VIEW) { if (op->src[0]->view_src->ne[1] != op->src[0]->ne[2]) { // GGML_LOG_WARN( @@ -913,6 +922,12 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { return true; } } + if (mode == GGML_ROPE_TYPE_IMROPE && + (op->src[2] != 0 || ((const float *) op_params)[6] != 1 || ((const float *) op_params)[7] != 0 || + ((const float *) op_params)[8] != 1)) { + // GGML_LOG_WARN("OpenVINO backend does not support IMROPE with freq_factors, freq_scale, ext_factor, and attn_factor\n"); + return true; + } break; } default: @@ -942,6 +957,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con // GGML_OP_SOFT_MAX, GGML_OP_SET_ROWS, GGML_OP_FLASH_ATTN_EXT, GGML_OP_CPY}; static const std::set supported_unary_ops{ + GGML_UNARY_OP_GELU, GGML_UNARY_OP_SILU, }; static const std::set supported_glu_ops{ diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index dbf38646ddd..57d66df4f01 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -46,6 +46,7 @@ void unpack_32_4(const uint8_t * data, uint8_t * dst) { // Extracts (weight, scales, zp) from Q4_0 tensors. // Data layout is: |16 bit scale|32 x 4bit weights|. +// When zp_arr is empty (symmetric), weights are stored as signed i4 (value - 8). void extract_q4_0_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, @@ -55,28 +56,32 @@ void extract_q4_0_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - // For Q4_0, zero point is always 8 - if (is_scalar_zp) { - zp[0] = 8 | (8 << 4); // Pack two 4-bit values - } + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i4); // Signed i4 path - ov::parallel_for(scales_arr.get_size(), [&](size_t i) { - scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); - // For asymmetric quantization, compute per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); // Pack two 4-bit zero points per byte if (i % 2 == 0) { zp[i / 2] = 8; // Lower nibble } else { zp[i / 2] |= (8 << 4); // Upper nibble } - } - unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); - }); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + }); + } else { + // Symmetric: unpack as u4 then convert to i4 by subtracting 8 (XOR each nibble) + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + // Convert u4 to i4: subtract 8 from each nibble. XOR 0x88 flips each nibble by 8. + for (int j = 0; j < 16; ++j) { + weights[i * 16 + j] ^= 0x88; + } + }); + } } // Extracts (weight, scales, zp) from Q4_1 tensors. @@ -123,6 +128,7 @@ void extract_q4_1_data(const ggml_tensor * tensor, // Extracts (weight, scales, zp) from Q8_0 tensors. // Data layout is: |16 bit scale|32 x 8bit weights|. +// When zp_arr is empty (symmetric), weights are stored as signed i8 directly. void extract_q8_0_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, @@ -133,29 +139,30 @@ void extract_q8_0_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - // For Q8_0, zero point is always 128 - if (is_scalar_zp) { - zp[0] = 128; - } + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - ov::parallel_for(scales_arr.get_size(), [&](size_t i) { - uint8_t * block_data = data + i * bytes_per_block; - scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); zp[i] = 128; - } - for (size_t j = 0; j < weights_per_block; ++j) { - uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. - // Original data is in int8_t, so we add a bias of -128 and invert the first bit. - x ^= 1 << 7; - weights[i * weights_per_block + j] = x; - } - }); + for (size_t j = 0; j < weights_per_block; ++j) { + uint8_t x = block_data[j + 2]; + x ^= 1 << 7; // Convert int8 to uint8 by flipping sign bit + weights[i * weights_per_block + j] = x; + } + }); + } else { + // Symmetric: store original int8 values directly (no unsigned bias) + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); + // Copy int8 weights as-is (the tensor element type is i8) + memcpy(weights + i * weights_per_block, block_data + 2, weights_per_block); + }); + } } void unpack_256_4(const uint8_t * data, uint8_t * dst) { @@ -256,44 +263,62 @@ void extract_q6_k_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - - // For Q6_K, zero point is always 32 - if (is_scalar_zp) { - zp[0] = 32; - } - - ov::parallel_for(n_super_block, [&](size_t i) { - uint8_t * block_data = data + i * bytes_per_block; - float scale_factor = - static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); // (128+64+16)/2 + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - for (size_t j = 0; j < 16; j++) { - scales[j + i * 16] = - ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + float scale_factor = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); zp[j + i * 16] = 32; } - } - - uint8_t * ql = block_data; - uint8_t * qh = block_data + 128; - - for (int64_t j = 0; j < 32; ++j) { - weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); - weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); - weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); - weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); - weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); - weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); - weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); - weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4); - } - }); + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + for (int64_t j = 0; j < 32; ++j) { + weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); + weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); + weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); + weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); + weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); + weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); + weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); + weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4); + } + }); + } else { + // Symmetric: subtract 32 from each weight to store as signed i8 + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + float scale_factor = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); + } + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + auto * signed_weights = reinterpret_cast(weights); + for (int64_t j = 0; j < 32; ++j) { + signed_weights[i * 256 + j] = static_cast((ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 32] = + static_cast((ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 64] = static_cast((ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 96] = + static_cast((ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 128] = + static_cast((ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 160] = + static_cast((ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 192] = + static_cast((ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 224] = + static_cast((ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4)) - 32; + } + }); + } } static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) { @@ -389,11 +414,10 @@ ov::Output make_int8_weights(ov::Tensor & weight, size_t group_size, bool use_bias) { ov::Shape orig_shape = weight.get_shape(); + bool is_signed = (weight.get_element_type() == ov::element::i8); // Symmetric: signed weights, no ZP // Expand dimensions for scales and zp/bias auto scale_shape = scales.get_shape(); - auto zp_shape = zp.get_shape(); - bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size}; @@ -403,37 +427,48 @@ ov::Output make_int8_weights(ov::Tensor & weight, } else { scale_shape.push_back(1); scales.set_shape(scale_shape); - // For symmetric quantization, zp remains scalar (don't resize) - if (!is_scalar_zp) { + if (!is_signed && zp.get_size() > 0) { + auto zp_shape = zp.get_shape(); zp_shape.push_back(1); zp.set_shape(zp_shape); } } - // Create graph nodes - auto weights_node = std::make_shared(ov::element::u8, packed_shape, - static_cast(weight.data()), nullptr); - weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; auto scales_f16 = std::make_shared(scales); - auto weights_f16 = std::make_shared(weights_node, ov::element::f16); ov::Output result; - if (use_bias && !is_scalar_zp) { - // Bias path: w * s + b (zp tensor holds f16 bias values) - auto bias_f16 = std::make_shared(zp); - auto w_s = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + if (is_signed) { + // Signed path: q * s (no zero point subtraction needed) + auto weights_node = std::make_shared(ov::element::i8, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + result = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); } else { - // Zero point path: (w - zp) * s - auto zero_point = std::make_shared(zp); - float zp_value; - if (ov::op::util::get_single_value(zero_point, zp_value)) { - zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + // Unsigned path + auto weights_node = std::make_shared(ov::element::u8, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + + if (use_bias && zp.get_size() > 0) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared(zp); + auto w_s = + std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_point = std::make_shared(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_point, zp_value)) { + zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + } + auto zero_point_f16 = std::make_shared(zero_point, ov::element::f16); + auto w_zp = + std::make_shared(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } - auto zero_point_f16 = std::make_shared(zero_point, ov::element::f16); - auto w_zp = - std::make_shared(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } if (packed_shape.size() != 2) { @@ -452,11 +487,10 @@ ov::Output make_int4_weights(ov::Tensor & weight, size_t group_size, bool use_bias) { ov::Shape orig_weight_shape = weight.get_shape(); + bool is_signed = (weight.get_element_type() == ov::element::i4); // Symmetric: signed weights, no ZP // Expand dimensions for scales and zp/bias ov::Shape scale_shape = scales.get_shape(); - auto zp_shape = zp.get_shape(); - bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization // Create INT4 weight tensor ov::Shape packed_shape = {orig_weight_shape[0], orig_weight_shape[1] / group_size, group_size}; @@ -467,36 +501,48 @@ ov::Output make_int4_weights(ov::Tensor & weight, } else { scale_shape.push_back(1); scales.set_shape(scale_shape); - // For symmetric quantization, zp remains scalar (don't resize) - if (!is_scalar_zp) { + if (!is_signed && zp.get_size() > 0) { + auto zp_shape = zp.get_shape(); zp_shape.push_back(1); zp.set_shape(zp_shape); } } - auto weights_node = std::make_shared(ov::element::u4, packed_shape, - static_cast(weight.data()), nullptr); - weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; - auto weights_f16 = std::make_shared(weights_node, ov::element::f16); auto scales_f16 = std::make_shared(scales); ov::Output result; - if (use_bias && !is_scalar_zp) { - // Bias path: w * s + b (zp tensor holds f16 bias values) - auto bias_f16 = std::make_shared(zp); - auto w_s = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + if (is_signed) { + // Signed path: q * s (no zero point subtraction needed) + auto weights_node = std::make_shared(ov::element::i4, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + result = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); } else { - // Zero point path: (w - zp) * s - auto zero_points_node = std::make_shared(zp); - float zp_value; - if (ov::op::util::get_single_value(zero_points_node, zp_value)) { - zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + // Unsigned path + auto weights_node = std::make_shared(ov::element::u4, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + + if (use_bias && zp.get_size() > 0) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared(zp); + auto w_s = + std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_points_node = std::make_shared(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_points_node, zp_value)) { + zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + } + auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); + auto w_zp = + std::make_shared(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } - auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); - auto w_zp = - std::make_shared(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } if (packed_shape.size() != 2) { @@ -699,24 +745,32 @@ OvWeight process_weight_tensor(const ggml_tensor * tensor, const void * data, vo // Quantized path (normal extraction or quantized requant) // Create weight/scale/zp tensors - shared between both paths - ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + // For symmetric quantization, use signed types (i4/i8) and no ZP tensor + ov::element::Type weight_type = layout.is_symmetric ? (layout.is_u4 ? ov::element::i4 : ov::element::i8) : + (layout.is_u4 ? ov::element::u4 : ov::element::u8); ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; - ov::Shape zp_shape = layout.is_symmetric ? ov::Shape{} : scale_shape; if (output_base_ptr) { uint8_t * buf_base = static_cast(output_base_ptr); result.weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); result.scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - result.zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset); + if (!layout.is_symmetric) { + ov::element::Type zp_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + result.zp = ov::Tensor(zp_type, scale_shape, buf_base + layout.zp_offset); + } + // else: result.zp remains default-constructed (empty) for symmetric } else { result.weights = ov::Tensor(weight_type, node_shape); result.scales = ov::Tensor(ov::element::f16, scale_shape); - if (use_bias && !layout.is_symmetric) { - // bias only has effect for asymmetric quant - result.zp = ov::Tensor(ov::element::f16, zp_shape); - } else { - result.zp = ov::Tensor(weight_type, zp_shape); + if (!layout.is_symmetric) { + if (use_bias) { + result.zp = ov::Tensor(ov::element::f16, scale_shape); + } else { + ov::element::Type zp_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + result.zp = ov::Tensor(zp_type, scale_shape); + } } + // else: result.zp remains default-constructed (empty) for symmetric } if (layout.is_requant && layout.requant_type.has_value()) { @@ -741,59 +795,75 @@ void quantize_q4_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - - // For Q4_0, zero point is always 8 - if (is_scalar_zp) { - zp[0] = 8 | (8 << 4); // Pack two 4-bit values - } + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i4); // Signed i4 path - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < qk; j++) { - const float v = x[i * qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + float max = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } } - } - - const float d = max / -8; - - if (d == 0) { - scales[i] = ov::float16(1.0f); - // zp is already set to 8 for symmetric, or set per-block for asymmetric - if (!is_scalar_zp) { + const float d = max / -8; + if (d == 0) { + scales[i] = ov::float16(1.0f); if (i % 2 == 0) { zp[i / 2] = 8; } else { zp[i / 2] |= (8 << 4); } + memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); + continue; } - memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); - continue; - } - - const float id = 1.0f / d; - scales[i] = ov::float16(d); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + const float id = 1.0f / d; + scales[i] = ov::float16(d); if (i % 2 == 0) { zp[i / 2] = 8; } else { zp[i / 2] |= (8 << 4); } + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); + weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } } - - for (int j = 0; j < qk / 2; ++j) { - const float x0 = x[i * qk + 2 * j] * id; - const float x1 = x[i * qk + 2 * j + 1] * id; - const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); - weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } else { + // Symmetric: produce signed i4 values in [-8, 7] + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + float max = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + const float d = max / -8; + if (d == 0) { + scales[i] = ov::float16(1.0f); + // i4 value 0 packed: 0x00 + memset(weights + i * qk / 2, 0, qk / 2); + continue; + } + const float id = 1.0f / d; + scales[i] = ov::float16(d); + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + // Signed i4: range [-8, 7]. Quantize as round(x*id), then pack as 4-bit two's complement. + int8_t si0 = (int8_t) std::max(-8, std::min(7, (int) roundf(x0))); + int8_t si1 = (int8_t) std::max(-8, std::min(7, (int) roundf(x1))); + weights[i * qk / 2 + j] = (si0 & 0x0F) | ((si1 & 0x0F) << 4); + } } } } @@ -809,36 +879,42 @@ void quantize_q8_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - - // For Q8_0, zero point is always 128 - if (is_scalar_zp) { - zp[0] = 128; - } - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - for (int j = 0; j < qk; j++) { - const float v = x[i * qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + amax = std::max(amax, fabsf(v)); } - } - - const float d = amax / 127.0f; - const float id = d ? 1.0f / d : 0.0f; - scales[i] = ov::float16(d); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + const float d = amax / 127.0f; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); zp[i] = 128; + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + const int8_t xi0 = roundf(x0); + weights[i * qk + j] = (uint8_t) (xi0 + 128); + } } - - for (int j = 0; j < qk; ++j) { - const float x0 = x[i * qk + j] * id; - const int8_t xi0 = roundf(x0); - weights[i * qk + j] = (uint8_t) (xi0 + 128); + } else { + // Symmetric: store signed int8 values directly + auto * signed_weights = reinterpret_cast(weights); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + amax = std::max(amax, fabsf(v)); + } + const float d = amax / 127.0f; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + signed_weights[i * qk + j] = (int8_t) roundf(x0); + } } } } @@ -861,12 +937,8 @@ void quantize_q8_1(const float * x, for (int j = 0; j < qk; j++) { const float v = x[i * qk + j]; - if (v < min) { - min = v; - } - if (v > max) { - max = v; - } + min = std::min(v, min); + max = std::max(v, max); } const float d = (max - min) / ((1 << 8) - 1); diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp index 26dc2d24f82..a8db9b38930 100644 --- a/ggml/src/ggml-openvino/openvino/op/rope.cpp +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -9,12 +9,17 @@ #include #include #include +#include +#include +#include #include #include #include +#include #include #include #include +#include #include #include @@ -33,6 +38,12 @@ OutputVector translate_rope(const NodeContext & context) { auto data_node = context.get_input(0).get_node_shared_ptr(); auto output_shape = context.get_output_shape().to_shape(); int32_t * op_params = context.get_output_op_params(); + const int mode = (op_case & 0xFFFF0000) >> 16; + op_case = (op_case & 0x0000FFFF); + + constexpr int TYPE_NORMAL = 0; + constexpr int TYPE_NEOX = 1; + constexpr int TYPE_IMROPE = 2; Output cos_theta_node; Output sin_theta_node; @@ -45,7 +56,7 @@ OutputVector translate_rope(const NodeContext & context) { if (context.get_input_size() == 3) { rope_freqs_weight = context.get_input(2).get_node_shared_ptr(); } - auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight); + auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight, mode == TYPE_IMROPE); sin_theta_node = sin_cos.first; cos_theta_node = sin_cos.second; } @@ -65,11 +76,7 @@ OutputVector translate_rope(const NodeContext & context) { } } - const int mode = op_params[2]; - constexpr int ROPE_TYPE_NORMAL = 0; - constexpr int ROPE_TYPE_NEOX = 2; - - if (mode == ROPE_TYPE_NORMAL) { + if (mode == TYPE_NORMAL) { auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); @@ -97,7 +104,7 @@ OutputVector translate_rope(const NodeContext & context) { auto data_shape = ov::op::v0::Constant::create( ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); res = std::make_shared(stack, data_shape, false); - } else if (mode == ROPE_TYPE_NEOX) { + } else if (mode == TYPE_NEOX) { auto data_split = std::make_shared( data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2); Output slice_data_node_0 = data_split->outputs()[0]; @@ -112,6 +119,25 @@ OutputVector translate_rope(const NodeContext & context) { std::make_shared(slice_data_node_1, cos_theta_node)); res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, -1); + } else if (mode == TYPE_IMROPE) { + int64_t n_dims = data_node->get_shape()[3]; + auto cos_sin_shape = std::make_shared(ov::element::i64, ov::Shape{4}, std::vector{1,-1,1,(n_dims >> 1)}); + auto cos_reshaped = std::make_shared(cos_theta_node, cos_sin_shape, true); + auto sin_reshaped = std::make_shared(sin_theta_node, cos_sin_shape, true); + + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}); + auto split_a = std::make_shared(data_node, split_axis, 2); + auto x0 = split_a->output(0); + auto x1 = split_a->output(1); + auto mul_a = std::make_shared(x0, cos_reshaped); + auto mul_b = std::make_shared(x1, sin_reshaped); + auto sub = std::make_shared(mul_a, mul_b); + + auto mul_c = std::make_shared(x0, sin_reshaped); + auto mul_d = std::make_shared(x1, cos_reshaped); + auto add = std::make_shared(mul_c, mul_d); + + res = std::make_shared(ov::OutputVector{sub, add}, 3); } return rename_outputs_with_suffix({res}, context.get_name()); diff --git a/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp b/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp new file mode 100644 index 00000000000..d1e9efc33a5 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp @@ -0,0 +1,25 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_unary_gelu(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + auto input = context.get_input(0); + auto res = std::make_shared(input); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp b/ggml/src/ggml-openvino/openvino/op_table.cpp index beadafe8103..1385539279c 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.cpp +++ b/ggml/src/ggml-openvino/openvino/op_table.cpp @@ -31,6 +31,7 @@ std::unordered_map get_supported_ops() { {"GGML_OP_SOFT_MAX", op::translate_soft_max }, {"GGML_OP_SUB", op::translate_1to1_match_2_inputs}, {"GGML_OP_TRANSPOSE", op::translate_transpose }, + {"GGML_UNARY_OP_GELU", op::translate_unary_gelu }, {"GGML_UNARY_OP_SILU", op::translate_unary_silu }, {"GGML_OP_VIEW", op::translate_view }, {"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu }, diff --git a/ggml/src/ggml-openvino/openvino/op_table.h b/ggml/src/ggml-openvino/openvino/op_table.h index 37f763117aa..f546796d2ee 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.h +++ b/ggml/src/ggml-openvino/openvino/op_table.h @@ -21,6 +21,7 @@ GGML_OP_CONVERTER(translate_rms_norm); GGML_OP_CONVERTER(translate_rope); GGML_OP_CONVERTER(translate_scale); GGML_OP_CONVERTER(translate_unary_silu); +GGML_OP_CONVERTER(translate_unary_gelu); GGML_OP_CONVERTER(translate_soft_max); GGML_OP_CONVERTER(translate_transpose); GGML_OP_CONVERTER(translate_view); diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp deleted file mode 100644 index ed2a3ab6d1b..00000000000 --- a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +++ /dev/null @@ -1,123 +0,0 @@ -#include "eliminate_zp.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace ov { -namespace frontend { -namespace ggml { -namespace pass { - -EliminateZeroPoints::EliminateZeroPoints() { - // Find pattern: - // (Multiply Any(scale) - // (Subtract (Convert Constant(data))) - // (Convert Constant(zero_point))) - // where zero_point is a scalar - // If data is u4 and zp value is 8 (q4_0), Replace the Subtract with an i4 Constant whose value is data - zp_val - // If data is u8 and zp value is 128 (q8_0) or 32 (q6_k), Replace the Subtract with an i8 Constant - - auto m_data_constant = ov::pass::pattern::wrap_type(); - auto m_data_convert = ov::pass::pattern::wrap_type({m_data_constant}); - - auto m_zp_constant = ov::pass::pattern::wrap_type(); - auto m_zp_convert = ov::pass::pattern::wrap_type({m_zp_constant}); - - auto m_subtract = ov::pass::pattern::wrap_type({m_data_convert, m_zp_convert}); - auto m_scale = ov::pass::pattern::any_input(); - auto m_multiply = ov::pass::pattern::wrap_type({m_scale, m_subtract}); - - const auto callback = [=](ov::pass::pattern::Matcher & m) { - const auto & pattern_map = m.get_pattern_value_map(); - - auto multiply_node = - std::dynamic_pointer_cast(pattern_map.at(m_multiply).get_node_shared_ptr()); - auto subtract_node = - std::dynamic_pointer_cast(pattern_map.at(m_subtract).get_node_shared_ptr()); - auto data_constant = - std::dynamic_pointer_cast(pattern_map.at(m_data_constant).get_node_shared_ptr()); - auto zp_constant = - std::dynamic_pointer_cast(pattern_map.at(m_zp_constant).get_node_shared_ptr()); - - if (!multiply_node || !subtract_node || !data_constant || !zp_constant) { - return false; - } - - if (ov::shape_size(zp_constant->get_shape()) != 1) { - return false; - } - - auto data_type = data_constant->get_element_type(); - auto zp_data = zp_constant->cast_vector(); - - if (zp_data.empty()) { - return false; - } - - int zp_value = zp_data[0]; - - bool should_eliminate = false; - ov::element::Type target_type; - - if (data_type == ov::element::u4 && zp_value == 8) { - should_eliminate = true; - target_type = ov::element::i4; - } else if (data_type == ov::element::u8 && (zp_value == 128 || zp_value == 32)) { - should_eliminate = true; - target_type = ov::element::i8; - } - - if (!should_eliminate) { - return false; - } - - auto data_shape = data_constant->get_shape(); - size_t total_elements = ov::shape_size(data_shape); - - std::shared_ptr new_constant; - - // TODO improve performance - if (data_type == ov::element::u4) { - auto data_values = data_constant->cast_vector(); - std::vector adjusted_values(total_elements); - - ov::parallel_for(total_elements, [&](size_t i) { - adjusted_values[i] = static_cast(static_cast(data_values[i]) - 8); - }); - - new_constant = std::make_shared(target_type, data_shape, adjusted_values); - } else if (data_type == ov::element::u8) { - auto data_values = data_constant->cast_vector(); - std::vector adjusted_values(total_elements); - - ov::parallel_for(total_elements, [&, zp_value](size_t i) { - adjusted_values[i] = static_cast(static_cast(data_values[i]) - zp_value); - }); - - new_constant = std::make_shared(target_type, data_shape, adjusted_values); - } - - auto new_convert = - std::make_shared(new_constant, subtract_node->get_output_element_type(0)); - ov::replace_node(subtract_node, new_convert); - - return true; - }; - - register_matcher( - std::make_shared(m_multiply, "ov::frontend::ggml::pass::EliminateZeroPoints"), - callback); -} - -} // namespace pass -} // namespace ggml -} // namespace frontend -} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h deleted file mode 100644 index edd3cd718d9..00000000000 --- a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +++ /dev/null @@ -1,17 +0,0 @@ -#include "openvino/pass/matcher_pass.hpp" - -namespace ov { -namespace frontend { -namespace ggml { -namespace pass { - -class EliminateZeroPoints : public ov::pass::MatcherPass { -public: - OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::EliminateZeroPoints") - EliminateZeroPoints(); -}; - -} // namespace pass -} // namespace ggml -} // namespace frontend -} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp b/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp new file mode 100644 index 00000000000..f051891c481 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp @@ -0,0 +1,41 @@ +// Copyright (C) 2018-2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace ov { + +/** + * @brief Holds weightless caching attributes of a single constant. + * + * WeightlessCacheAttribute class represents runtime info attribute that holds + * the values of original size of the constant in bytes and the binary offset of the + * constant's data in the weights file used by the weightless caching mechanism. It's + * not copyable in case the data was changed (the original node was replaced by a new + * one produced during the tranformation pipeline) - in that case weightless caching + * can't be used for that constant. + */ +class OPENVINO_API WeightlessCacheAttribute : public RuntimeAttribute { +public: + OPENVINO_RTTI("WeightlessCacheAttribute", "0", RuntimeAttribute) + + WeightlessCacheAttribute() = delete; + + WeightlessCacheAttribute(size_t original_size, size_t bin_offset, ov::element::Type original_dtype) + : original_size(original_size), + bin_offset(bin_offset), + original_dtype(original_dtype) {} + + bool is_copyable() const override; + + size_t original_size; + size_t bin_offset; + ov::element::Type original_dtype; +}; + +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index 23a1dea2496..0f68a1f5062 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -3,15 +3,16 @@ #include "ggml-openvino/openvino/node_context.h" #include "ggml-openvino/openvino/utils.h" #include "input_model.h" -#include "pass/eliminate_zp.h" #include "pass/mark_decompression_convert_constant_folding.h" #include "pass/squeeze_matmul.h" +#include "rt_info/weightless_caching_attributes.hpp" #include #include #include #include #include +#include #include #include #include @@ -33,7 +34,6 @@ #include #include #include -#include namespace ov { namespace frontend { @@ -240,6 +240,31 @@ std::shared_ptr TranslateSession::translate_graph(const frontend::InputMo resulting_model = std::make_shared(results, used_params); apply_transformations(resulting_model); + + // Set WeightlessCacheAttribute on large constants to avoid unnecessary memory copies + // in the NPUW plugin. Without this attribute, NPUW's LazyTensor constructor + // (lazy_tensor.cpp, op::Const::Const) will memcpy every constant "in case export + // occurs", doubling memory usage per compile_model call. + // + // The bin_offset field serves as a unique key (not a real file offset) — this is + // the same convention the GPU plugin uses for non-IR models (see + // Plugin::set_weightless_cache_attributes in intel_gpu/src/plugin/plugin.cpp). + // Each constant must have a distinct bin_offset, otherwise GPU's weightless cache + // import will map multiple constants to the same data. + // + // Small constants (< 16 elements) are excluded since they may be introduced by + // optimization patterns and the overhead is negligible. + size_t offset = 0; + for (auto & node : resulting_model->get_ordered_ops()) { + if (auto cnst = ov::as_type_ptr(node); + cnst && cnst->get_byte_size() / cnst->get_element_type().size() >= 16) { + auto & rt_info = cnst->get_rt_info(); + if (rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static()) == rt_info.end()) { + rt_info[ov::WeightlessCacheAttribute::get_type_info_static()] = + ov::WeightlessCacheAttribute(cnst->get_byte_size(), offset++, cnst->get_element_type()); + } + } + } return resulting_model; } @@ -257,7 +282,6 @@ std::shared_ptr TranslateSession::apply_transformations(std::shared_ptris_static()) { - manager.register_pass(); manager.register_pass(); } manager.run_passes(model); diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp index 65356a51b51..0baaf88e17a 100644 --- a/ggml/src/ggml-openvino/openvino/utils.cpp +++ b/ggml/src/ggml-openvino/openvino/utils.cpp @@ -2,6 +2,7 @@ #include "ggml-impl.h" +#include #include #include #include @@ -13,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -87,8 +89,11 @@ ov::Output rope_yarn_ramp_mix(int n_dims, const float corr_dims[2], fl auto ramp_y = std::make_shared(std::make_shared(dim_ids, corr_low), denom); auto ramp_clamped = std::make_shared(ramp_y, 0.0f, 1.0f); + // rope_yarn_ramp returns (1 - clamp(y)), so invert before scaling + auto one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + auto ramp_inverted = std::make_shared(one, ramp_clamped); auto ext_factor_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {ext_factor}); - auto ramp_mix = std::make_shared(ramp_clamped, ext_factor_node); + auto ramp_mix = std::make_shared(ramp_inverted, ext_factor_node); return ramp_mix; } @@ -115,6 +120,7 @@ void ggml_rope_yarn_corr_dims(int n_dims, std::pair, ov::Output> make_sin_cos(int32_t * rope_params, std::shared_ptr inp_pos, std::shared_ptr rope_freqs_weight, + bool imrope, bool stateful) { if (stateful) { inp_pos = std::make_shared(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); @@ -122,6 +128,13 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params auto pos_perm = std::make_shared(ov::element::i64, ov::Shape{3}, std::vector{2, 1, 0}); inp_pos = std::make_shared(inp_pos, pos_perm); + } else if (imrope) { + inp_pos = std::make_shared(inp_pos, ov::element::f32); + auto pos_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{5}, {0, 0, 0, 4, -1}); + inp_pos = std::make_shared(inp_pos, pos_shape, true); + auto pos_transpose_shape = + std::make_shared(ov::element::i64, ov::Shape{5}, std::vector{0, 1, 2, 4, 3}); + inp_pos = std::make_shared(inp_pos, pos_transpose_shape); } else { inp_pos = std::make_shared(inp_pos, ov::element::f32); auto pos_perm = @@ -136,6 +149,7 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params float beta_fast; float beta_slow; const int n_dims = rope_params[1]; + const size_t n_dims_half = n_dims >> 1; const int n_ctx_orig = rope_params[4]; memcpy(&freq_base, rope_params + 5, sizeof(float)); memcpy(&freq_scale, rope_params + 6, sizeof(float)); @@ -146,57 +160,74 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params const float theta_scale = powf(freq_base, -2.0f / n_dims); - float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - - std::vector factor(n_dims / 2); - factor[0] = 1.0f; - for (size_t i = 1; i < factor.size(); i++) { - factor[i] = theta_scale * factor[i - 1]; - } + std::vector factor(n_dims_half); Output freq_factors; - if (stateful) { - freq_factors = - std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); - } else { - freq_factors = - std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); - } - if (rope_freqs_weight) { - freq_factors = std::make_shared(freq_factors, rope_freqs_weight); - } - - auto theta_extrap = std::make_shared(freq_factors, inp_pos); - auto theta_interp = std::make_shared( - theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); Output theta; float mscale = attn_factor; - if (ext_factor == 0.0f) { - theta = theta_interp; + if (imrope) { + std::vector gather_indices(n_dims_half); + for (size_t j = 0; j < n_dims_half; j++) { + gather_indices[j] = j % 3; + factor[j] = std::pow(theta_scale, j); + } + auto gather_indices_const = + std::make_shared(ov::element::i64, ov::Shape{n_dims_half}, gather_indices); + auto gather_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {4}); + inp_pos = std::make_shared(inp_pos, gather_indices_const, gather_axis); + auto factor_const = std::make_shared(ov::element::f32, ov::Shape{n_dims_half}, factor); + theta = std::make_shared(inp_pos, factor_const); } else { - auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); - Output one; + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + factor[0] = 1.0f; + for (size_t i = 1; i < factor.size(); i++) { + factor[i] = theta_scale * factor[i - 1]; + } if (stateful) { - one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); } else { - one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); + } + if (rope_freqs_weight) { + freq_factors = std::make_shared(freq_factors, rope_freqs_weight); } - auto one_minus_ramp = std::make_shared(one, ramp_mix); - theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), - std::make_shared(theta_extrap, ramp_mix)); - mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + auto theta_extrap = std::make_shared(freq_factors, inp_pos); + auto theta_interp = std::make_shared( + theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); + + if (ext_factor == 0.0f) { + theta = theta_interp; + } else { + auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); + Output one; + if (stateful) { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + } else { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + } + auto one_minus_ramp = std::make_shared(one, ramp_mix); + + theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), + std::make_shared(theta_extrap, ramp_mix)); + mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + } } Output cos_theta = std::make_shared(theta); Output sin_theta = std::make_shared(theta); - auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + if (!imrope) { + auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + + cos_theta = std::make_shared(cos_theta, mscale_node); + sin_theta = std::make_shared(sin_theta, mscale_node); + } - cos_theta = std::make_shared(cos_theta, mscale_node); - sin_theta = std::make_shared(sin_theta, mscale_node); return std::make_pair(sin_theta, cos_theta); } diff --git a/ggml/src/ggml-openvino/openvino/utils.h b/ggml/src/ggml-openvino/openvino/utils.h index 88dcad4c906..767dd4c53ea 100644 --- a/ggml/src/ggml-openvino/openvino/utils.h +++ b/ggml/src/ggml-openvino/openvino/utils.h @@ -67,6 +67,7 @@ OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std:: std::pair, ov::Output> make_sin_cos(int32_t* rope_params, std::shared_ptr inp_pos, std::shared_ptr rope_freqs_weight = nullptr, + bool imrope = false, bool stateful = false); ov::Output process_view_input(const NodeContext& context, int input_index, int slice_len = 0); diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 1b553a0de00..998ef7c9eb4 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -81,8 +81,8 @@ ov::Tensor create_ov_output_tensor(std::shared_ptr ggml_decoder, enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr r_ctx) { auto & core = ov_singleton_core(); const auto & config = ggml_openvino_get_compile_config(); - auto device = r_ctx->device; - bool stateful = r_ctx->stateful; + const auto & device = r_ctx->device; + const auto & stateful = r_ctx->stateful; static auto is_static = false; if (is_naive(cgraph)) { @@ -106,14 +106,26 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< int64_t infer_end_time; { - std::lock_guard lock(r_ctx->ov_compute_mutex); + std::shared_ptr entry; + ModelParams old_m_params; - auto it = r_ctx->decoder_cache.find(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + auto it = r_ctx->decoder_cache.find(key); + cache_hit = it != r_ctx->decoder_cache.end(); + if (cache_hit) { + entry = it->second; + } else { + auto mutex = std::make_shared(); + entry = std::make_shared(mutex); + r_ctx->decoder_cache[key] = entry; + } + } + + std::lock_guard lock(*(entry->mutex)); - cache_hit = it != r_ctx->decoder_cache.end(); - ModelParams old_m_params; if (cache_hit) { - ggml_decoder = it->second; + ggml_decoder = entry->ptr; old_m_params = ggml_decoder->get_model_params(); cache_hit = old_m_params.can_reuse_dynamically(m_params); } @@ -126,7 +138,10 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< ggml_decoder->update_io(cgraph); } ggml_decoder->add_extra_inputs(); - infer_request = r_ctx->infer_request_cache.at(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + infer_request = r_ctx->infer_request_cache.at(key); + } if (stateful) { const auto * inp_pos = get_inp_pos_tensor(cgraph); @@ -170,7 +185,10 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< conversion_end_time = decoder_end_time; compile_end_time = decoder_end_time; } else { - r_ctx->infer_request_cache.erase(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache.erase(key); + } std::shared_ptr model; auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); @@ -199,8 +217,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< } compile_end_time = ggml_time_us(); infer_request = std::make_shared(compiled_model.create_infer_request()); - r_ctx->infer_request_cache[key] = infer_request; - r_ctx->decoder_cache[key] = ggml_decoder; + entry->ptr = ggml_decoder; std::vector ov_input_names; std::vector ov_output_names; @@ -210,8 +227,13 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< for (const auto & ov_output : model->get_results()) { ov_output_names.push_back(ov_output->get_friendly_name()); } - r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); - r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache[key] = infer_request; + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } if (stateful) { const auto * inp_pos = get_inp_pos_tensor(cgraph); @@ -224,8 +246,13 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< } } - auto ov_input_names = r_ctx->ov_input_names_cache[key]; - auto ov_output_names = r_ctx->ov_output_names_cache[key]; + std::vector ov_input_names; + std::vector ov_output_names; + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + ov_input_names = r_ctx->ov_input_names_cache[key]; + ov_output_names = r_ctx->ov_output_names_cache[key]; + } for (size_t i = 0; i < ov_input_names.size(); i++) { auto param_name = ov_input_names[i]; @@ -306,12 +333,26 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrdecoder_cache.find(key); - - cache_hit = it != r_ctx->decoder_cache.end(); + std::shared_ptr entry; ModelParams old_m_params; + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + auto it = r_ctx->decoder_cache.find(key); + cache_hit = it != r_ctx->decoder_cache.end(); + if (cache_hit) { + entry = it->second; + } else { + auto mutex = std::make_shared(); + entry = std::make_shared(mutex); + r_ctx->decoder_cache[key] = entry; + } + } + + std::lock_guard lock(*(entry->mutex)); + if (cache_hit) { - ggml_decoder = it->second; + ggml_decoder = entry->ptr; old_m_params = ggml_decoder->get_model_params(); cache_hit = old_m_params.can_reuse_statically(m_params); } @@ -325,14 +366,21 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrupdate_io(cgraph); } ggml_decoder->add_extra_inputs(); - infer_request = is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + infer_request = + is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key); + } decoder_end_time = ggml_time_us(); conversion_end_time = decoder_end_time; compile_end_time = decoder_end_time; } else { - r_ctx->infer_request_cache.erase(key); - r_ctx->infer_request_cache_prefill.erase(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache.erase(key); + r_ctx->infer_request_cache_prefill.erase(key); + } std::shared_ptr model; auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); @@ -372,16 +420,14 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrinfer_request_cache_prefill[key] = - std::make_shared(compiled_model_prefill.create_infer_request()); - r_ctx->infer_request_cache[key] = - std::make_shared(compiled_model_decode.create_infer_request()); + auto infer_request_prefill = std::make_shared(compiled_model_prefill.create_infer_request()); + auto infer_request_decode = std::make_shared(compiled_model_decode.create_infer_request()); compile_end_time = ggml_time_us(); model = is_prefill ? model_prefill : model_decode; ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode; - infer_request = is_prefill ? r_ctx->infer_request_cache_prefill[key] : r_ctx->infer_request_cache[key]; - r_ctx->decoder_cache[key] = ggml_decoder; + infer_request = is_prefill ? infer_request_prefill : infer_request_decode; + entry->ptr = ggml_decoder; std::vector ov_input_names; std::vector ov_output_names; @@ -391,18 +437,29 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_results()) { ov_output_names.push_back(ov_output->get_friendly_name()); } - r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); - r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache_prefill[key] = infer_request_prefill; + r_ctx->infer_request_cache[key] = infer_request_decode; + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } } - auto ov_input_names = r_ctx->ov_input_names_cache[key]; - auto ov_output_names = r_ctx->ov_output_names_cache[key]; + std::vector ov_input_names_local; + std::vector ov_output_names_local; + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + ov_input_names_local = r_ctx->ov_input_names_cache[key]; + ov_output_names_local = r_ctx->ov_output_names_cache[key]; + } if (is_prefill) { auto inp_len = inp_pos->ne[0]; for (int chunk_index = 0; chunk_index * prefill_chunk_size < inp_len; chunk_index++) { - for (size_t i = 0; i < ov_input_names.size(); i++) { - auto param_name = ov_input_names[i]; + for (size_t i = 0; i < ov_input_names_local.size(); i++) { + auto param_name = ov_input_names_local[i]; auto input_tensor = get_ov_input_tensor_static_prefill(ggml_decoder, param_name, chunk_index); infer_request->set_input_tensor(i, input_tensor); @@ -412,8 +469,8 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_model_outputs().at(ov_output_names[i]); + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names_local[i]); auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); infer_request->set_output_tensor(i, output_tensor); } @@ -421,16 +478,16 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrinfer(); if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) { - for (size_t i = 0; i < ov_output_names.size(); i++) { + for (size_t i = 0; i < ov_output_names_local.size(); i++) { const auto output_tensor = infer_request->get_output_tensor(i); - print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data()); + print_output_tensor_info(ov_output_names_local[i], output_tensor, output_tensor.data()); } } } infer_end_time = ggml_time_us(); } else { - for (size_t i = 0; i < ov_input_names.size(); i++) { - auto param_name = ov_input_names[i]; + for (size_t i = 0; i < ov_input_names_local.size(); i++) { + auto param_name = ov_input_names_local[i]; auto input_tensor = get_ov_input_tensor_static_decode(ggml_decoder, param_name); infer_request->set_input_tensor(i, input_tensor); @@ -440,8 +497,8 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_model_outputs().at(ov_output_names[i]); + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names_local[i]); auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); infer_request->set_output_tensor(i, output_tensor); } @@ -450,9 +507,9 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_output_tensor(i); - print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data()); + print_output_tensor_info(ov_output_names_local[i], output_tensor, output_tensor.data()); } } } diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h index 656573d1389..2c72e33c352 100644 --- a/ggml/src/ggml-openvino/utils.h +++ b/ggml/src/ggml-openvino/utils.h @@ -3,12 +3,15 @@ #include "ggml-impl.h" #include +#include #include #include +#include #include #include #include #include +#include #include struct graph_key { @@ -40,11 +43,17 @@ struct graph_key_hash { } }; +struct decoder_runtime_ctx { + decoder_runtime_ctx(std::shared_ptr mutex) : mutex(std::move(mutex)) {} + std::shared_ptr mutex; + std::shared_ptr ptr; +}; + struct ov_runtime_context { - std::mutex ov_compute_mutex; + mutable std::mutex ctx_mutex; std::string device; bool stateful; - std::unordered_map, graph_key_hash> decoder_cache; + std::unordered_map, graph_key_hash> decoder_cache; std::unordered_map, graph_key_hash> infer_request_cache; std::unordered_map, graph_key_hash> infer_request_cache_prefill; std::unordered_map, graph_key_hash> ov_input_names_cache; @@ -53,11 +62,22 @@ struct ov_runtime_context { // Simultanous stateful inference request support to be added. size_t stateful_kv_size; std::map kv_state_input_name_map; + std::atomic backend_count; ov_runtime_context() : device("CPU"), stateful(false), - stateful_kv_size(0) {} + stateful_kv_size(0), + backend_count(0) {} + + void clear_caches() { + std::lock_guard lock(ctx_mutex); + decoder_cache.clear(); + infer_request_cache.clear(); + infer_request_cache_prefill.clear(); + ov_input_names_cache.clear(); + ov_output_names_cache.clear(); + } }; enum ggml_status ov_graph_compute(struct ggml_cgraph * cgraph, ggml_backend_t backend); From e2014d6959fd6194d434daf1ea199715b427beba Mon Sep 17 00:00:00 2001 From: Mengsheng Wu Date: Wed, 22 Apr 2026 04:53:44 +0800 Subject: [PATCH 170/249] hexagon: fix missing v79 entry in libggml-htp.inf (llama/22194) --- ggml/src/ggml-hexagon/libggml-htp.inf | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf index 656d2d9ab26..360d8b1228e 100644 --- a/ggml/src/ggml-hexagon/libggml-htp.inf +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -18,6 +18,7 @@ libggml-htp-v68.so = 1 libggml-htp-v69.so = 1 libggml-htp-v73.so = 1 libggml-htp-v75.so = 1 +libggml-htp-v79.so = 1 libggml-htp-v81.so = 1 [ControlFlags] @@ -31,6 +32,7 @@ libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE [Strings] From 84a6b5c03903504ccfd7bfad321d8c6dc9fbd708 Mon Sep 17 00:00:00 2001 From: Shreya Jain Date: Tue, 21 Apr 2026 14:16:04 -0700 Subject: [PATCH 171/249] Hexagon: DAIG op (llama/22195) * hexagon: Add DIAG op * hexagon: add HVX support and DMA double buffering * hexagon: fix fatal error * hexagon: remove as many pragma(s) as possible --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 28 +++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/diag-ops.c | 216 +++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + 6 files changed, 250 insertions(+) create mode 100644 ggml/src/ggml-hexagon/htp/diag-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3d68b80048f..5e206c5e9de 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2596,6 +2596,29 @@ static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + // diag only supports F32 currently + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + // Input must have ne[1] == 1 (vector input) + if (src0->ne[1] != 1) { + return false; + } + + // Output must be square in first two dimensions + if (dst->ne[0] != dst->ne[1] || dst->ne[0] != src0->ne[0]) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->c_name(); @@ -2632,6 +2655,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_ROPE: return HTP_OP_ROPE; case GGML_OP_REPEAT: return HTP_OP_REPEAT; case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; + case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { @@ -3159,6 +3183,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cumsum(sess, op); break; + case GGML_OP_DIAG: + supp = ggml_hexagon_supported_diag(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 9ca759459d4..82c10b57bbf 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -34,6 +34,7 @@ add_library(${HTP_LIB} SHARED argsort-ops.c ssm-conv.c cumsum-ops.c + diag-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/diag-ops.c b/ggml/src/ggml-hexagon/htp/diag-ops.c new file mode 100644 index 00000000000..9b3194d9084 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/diag-ops.c @@ -0,0 +1,216 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hex-utils.h" +#include "hvx-copy.h" +#include "hex-dma.h" + +#define htp_diag_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + \ + const uint32_t ne02 = src0->ne[2]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_diag_context { + struct htp_ops_context * octx; + size_t src_batch_size; + size_t dst_row_size; + size_t src_batch_size_aligned; + size_t dst_row_size_aligned; + uint32_t batches_per_thread; + uint32_t total_batches; +}; + +#define htp_diag_preamble \ + struct htp_diag_context * dctx = (struct htp_diag_context *) data; \ + struct htp_ops_context * octx = dctx->octx; \ + htp_diag_tensors_preamble; + +static inline void hvx_diag_row_f32(const float * restrict src, float * restrict dst, + uint32_t row_idx, uint32_t n) { + hvx_splat_f32_a((uint8_t *) dst, 0.0f, n); + dst[row_idx] = src[row_idx]; +} + +// --------------------------------------------------------------------------- +// Per thread worker: DMA src fetch, compute in VTCM, DMA dst writeback +// --------------------------------------------------------------------------- + +static void diag_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + dma_queue * dma_queue = octx->ctx->dma[ith]; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + if (ib0 >= ib1) { + return; + } + + const size_t src_batch_size = dctx->src_batch_size; + const size_t dst_row_size = dctx->dst_row_size; + const size_t src_batch_size_aligned = dctx->src_batch_size_aligned; + const size_t dst_row_size_aligned = dctx->dst_row_size_aligned; + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + // 1 src buffer + 1 dst row buffer per thread in VTCM + uint8_t * src_spad = octx->src0_spad.data + (ith * src_batch_size_aligned); + uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const uint8_t * src_batch = src_data + i3 * nb03 + i2 * nb02; + + // Fetch source vector into VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src_spad, src_batch), + src_batch_size_aligned, src_batch_size, 1); + dma_queue_flush(dma_queue); + + const float * src_spad_f32 = (const float *) src_spad; + float * dst_spad_f32 = (float *) dst_spad; + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + // Compute row in VTCM + hvx_diag_row_f32(src_spad_f32, dst_spad_f32, i1, ne0); + + // Write completed row back to DDR + uint8_t * dst_row = dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1; + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_row, dst_spad), + dst_row_size, dst_row_size_aligned, 1); + dma_queue_flush(dma_queue); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// Per thread worker: Direct HVX (no DMA) +// --------------------------------------------------------------------------- + +static void diag_thread_f32(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const float * restrict src_batch = (const float *)(src_data + i3 * nb03 + i2 * nb02); + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + float * restrict dst_row = (float *)(dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1); + hvx_diag_row_f32(src_batch, dst_row, i1, ne0); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_diag_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_batches); + + const size_t src_batch_size = src0->ne[0] * sizeof(float); + const size_t dst_row_size = dst->ne[0] * sizeof(float); + const size_t src_batch_size_aligned = hex_round_up(src_batch_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // 1 src buffer + 1 dst row buffer per thread + const size_t spad_per_thread = src_batch_size_aligned + dst_row_size_aligned; + + octx->src0_spad.size_per_thread = src_batch_size_aligned; + octx->dst_spad.size_per_thread = dst_row_size_aligned; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; + + struct htp_diag_context dctx = { + .octx = octx, + .src_batch_size = src_batch_size, + .dst_row_size = dst_row_size, + .src_batch_size_aligned = src_batch_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + .batches_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_batches = total_batches, + }; + + if (octx->ctx->vtcm_size < spad_per_thread * n_threads) { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32, &dctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32_dma, &dctx, n_threads); + } + + return HTP_STATUS_OK; +} + +int op_diag(struct htp_ops_context * octx) { + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_diag_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 8b5e47adef8..038941af0f2 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -98,5 +98,6 @@ int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); +int op_diag(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 79b5ecd2270..002dd1c12d2 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -80,6 +80,7 @@ enum htp_op_code { HTP_OP_SSM_CONV, HTP_OP_REPEAT, HTP_OP_CUMSUM, + HTP_OP_DIAG, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 5091623a653..d633145c909 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -514,6 +514,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_CUMSUM: return op_cumsum(octx); + case HTP_OP_DIAG: + return op_diag(octx); + case HTP_OP_INVALID: break; From 2e5eb6e9512a51129827698576307c7d4f5148d4 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Wed, 22 Apr 2026 08:05:21 +0900 Subject: [PATCH 172/249] ggml-webgpu: reset CPU/GPU profiling time when freeing context (llama/22050) * Reset the CPU/GPU profiling time when freeing context. * move GPU profiling time from global context to webgpu_context. --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index aa20a745e0a..a2923145230 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -211,6 +211,7 @@ struct webgpu_global_context_struct { wgpu::Buffer memset_params_buf; webgpu_pipeline memset_pipeline; + // TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050 #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) std::unordered_map cpu_time_ms; @@ -218,11 +219,6 @@ struct webgpu_global_context_struct { std::unordered_map cpu_detail_ms; #endif -#ifdef GGML_WEBGPU_GPU_PROFILE - // Profiling: per-shader GPU time in ms - std::unordered_map shader_gpu_time_ms; -#endif - #ifdef GGML_WEBGPU_DEBUG wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; @@ -268,10 +264,12 @@ struct webgpu_context_struct { size_t memset_bytes_per_thread; #ifdef GGML_WEBGPU_GPU_PROFILE - wgpu::Buffer profile_timestamp_dev_buf; - wgpu::Buffer profile_timestamp_host_buf; - wgpu::QuerySet profile_timestamp_query_set; - uint32_t profile_timestamp_query_count = 0; + // Profiling: per-shader GPU time in ms + std::unordered_map shader_gpu_time_ms; + wgpu::Buffer profile_timestamp_dev_buf; + wgpu::Buffer profile_timestamp_host_buf; + wgpu::QuerySet profile_timestamp_query_set; + uint32_t profile_timestamp_query_count = 0; #endif ~webgpu_context_struct() { @@ -713,12 +711,12 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #ifdef GGML_WEBGPU_GPU_PROFILE std::cout << "\n[ggml_webgpu gpu profiling summary]\n"; double total_gpu = 0.0; - for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { total_gpu += kv.second; } std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n"; std::cout << "\nggml_webgpu: gpu breakdown:\n"; - for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2) << pct << "%)\n"; @@ -2511,7 +2509,7 @@ static void ggml_backend_webgpu_collect_profile_results(webgpu_context & for (size_t i = 0; i < pipeline_names.size(); ++i) { // WebGPU timestamps are in ns; convert to ms. const double elapsed_ms = double(ts_data[2 * i + 1] - ts_data[2 * i]) * 1e-6; - ctx->global_ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; + ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; } ctx->profile_timestamp_host_buf.Unmap(); From d6a417408c5a764ff484a0210d5d99a55af9d8c9 Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Wed, 22 Apr 2026 04:54:20 +0530 Subject: [PATCH 173/249] hexagon: add support for FILL op (llama/22198) Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 16 +++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/fill-ops.c | 123 +++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + 6 files changed, 145 insertions(+) create mode 100644 ggml/src/ggml-hexagon/htp/fill-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 5e206c5e9de..cdd9fcf5928 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2655,6 +2655,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_ROPE: return HTP_OP_ROPE; case GGML_OP_REPEAT: return HTP_OP_REPEAT; case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; + case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_UNARY: @@ -3053,6 +3054,17 @@ static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_fill(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * dst = op; + + if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast(dev->context); @@ -3183,6 +3195,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cumsum(sess, op); break; + case GGML_OP_FILL: + supp = ggml_hexagon_supported_fill(sess, op); + break; + case GGML_OP_DIAG: supp = ggml_hexagon_supported_diag(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 82c10b57bbf..b1ae60a9c43 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -34,6 +34,7 @@ add_library(${HTP_LIB} SHARED argsort-ops.c ssm-conv.c cumsum-ops.c + fill-ops.c diag-ops.c ) diff --git a/ggml/src/ggml-hexagon/htp/fill-ops.c b/ggml/src/ggml-hexagon/htp/fill-ops.c new file mode 100644 index 00000000000..3ccfbe74ee4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/fill-ops.c @@ -0,0 +1,123 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hvx-copy.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +// ggml op_params layout for FILL: +// op_params[0] (as float) - the scalar fill value + +#define fill_preamble \ + const struct htp_tensor * dst = octx->dst; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const uint32_t nr = ne1 * ne2 * ne3; + +struct htp_fill_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; + uint32_t total_rows; // ne1 * ne2 * ne3 + bool opt_path; + HVX_Vector splat_vec; + uint32_t elem_size; +}; + +static void fill_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_fill_context * fctx = (const struct htp_fill_context *) data; + struct htp_ops_context * octx = fctx->octx; + fill_preamble; + + // Parallelise over the flat row index spanning ne1*ne2*ne3 + const uint32_t ir0 = fctx->nrows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + fctx->nrows_per_thread, fctx->total_rows); + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + if (fctx->opt_path) { + // Opt path: tensor is fully contiguous, treat as flat array + const uint32_t elem_start = ir0 * ne0; + const uint32_t elem_end = ir1 * ne0; + uint8_t * dst_ptr = (uint8_t *) dst->data + elem_start * fctx->elem_size; + hvx_splat_u(dst_ptr, fctx->splat_vec, elem_end - elem_start, fctx->elem_size); + } else { + // Non-contiguous path: must respect strides + for (uint32_t ir = ir0; ir < ir1; ++ir) { + const uint32_t i1 = ir % ne1; + const uint32_t i2 = (ir / ne1) % ne2; + const uint32_t i3 = ir / (ne1 * ne2); + uint8_t * dst_ptr = (uint8_t *) dst->data + i1*nb1 + i2*nb2 + i3*nb3; + hvx_splat_u(dst_ptr, fctx->splat_vec, ne0, fctx->elem_size); + } + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + FARF(HIGH, "fill %u/%u: rows %u:%u usec %u\n", + ith, nth, ir0, ir1, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_fill(struct htp_ops_context * octx) { + fill_preamble; + + if (dst->type != HTP_TYPE_F32 && dst->type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + // nr = ne1*ne2*ne3 (flat row count across all outer dims); parallelise over it. + const uint32_t n_threads = MIN(nr, octx->n_threads); + + // Optimize if fully contiguous: skip stride arithmetic, treat as flat array + const bool opt_path = (nb2 == nb1 * ne1) && (nb3 == nb2 * ne2); + + FARF(HIGH, "fill: (%ux%ux%ux%u) type=%u opt=%d\n", + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->type, (int) opt_path); + + float val_f32 = 0.f; + memcpy(&val_f32, &octx->op_params[0], sizeof(float)); + + struct htp_fill_context fctx = { + .octx = octx, + .nrows_per_thread = (nr + n_threads - 1) / n_threads, + .total_rows = nr, + .opt_path = opt_path, + }; + + switch (dst->type) { + case HTP_TYPE_F32: + fctx.splat_vec = hvx_vec_splat_f32(val_f32); + fctx.elem_size = sizeof(float); + break; + case HTP_TYPE_F16: + fctx.splat_vec = hvx_vec_splat_f16((_Float16) val_f32); + fctx.elem_size = sizeof(_Float16); + break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + worker_pool_run_func(octx->ctx->worker_pool, fill_thread, &fctx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 038941af0f2..78455e6b071 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -98,6 +98,7 @@ int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); +int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 002dd1c12d2..62d6ec02241 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -80,6 +80,7 @@ enum htp_op_code { HTP_OP_SSM_CONV, HTP_OP_REPEAT, HTP_OP_CUMSUM, + HTP_OP_FILL, HTP_OP_DIAG, HTP_OP_INVALID diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index d633145c909..9185c9ffe15 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -514,6 +514,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_CUMSUM: return op_cumsum(octx); + case HTP_OP_FILL: + return op_fill(octx); + case HTP_OP_DIAG: return op_diag(octx); From 447be522e91bc83679fa714eb40f3e994c2aaa73 Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Tue, 21 Apr 2026 23:18:57 -0400 Subject: [PATCH 174/249] ggml-webgpu(shader): support conv2d kernels. (llama/21964) * ggml(webgpu): fix the busy-polls in Emscripten in the waitAny after #20618, and remove the busy webgpu log * Merge with upstream * Fix GET_ROWS packed integer NaN when using f16 as memory buffer in shader quants * Update Unary wgsl EXP and EXPM1 for f16 stability * Fix GET_ROWS IQ4_XS strcut for NaN f16 canonicalization * Fix numerical percision for unary sqrt when working with f16 * Fix NaN canonicalization for packed integers using f16 * Update err threshold for binary div ops when using f16 * backend: Keep one Dawn/WebGPU instance alive for the lifetime of the static backend * clean: uncomment existing code logs * clean: clean the unncessary debug info * Refactor and generalize dequant helpers * Remove deprecated quant structs * Refactor shader defines to reduce repetition * Remove error override for F16 type * fix: fix the accidential removal of the proper initialization of ctx * clean: clean legacy and format code * fix: did not modify tests ops * shader(conv2d): add conv2d shader kernels and pass f32 and f16 tests * shader(conv2d): fix the out of bounds memory access in the weight indexing * shader(conv2d): clean unused variables and optimize the computation * merge: use the new entries function * clean: address the formatting issues * clean: address the warning issues * clear: clean the shader editorconfig-checker issues * clear: clean the shader editorconfig-checker with utf-8 --------- Co-authored-by: Jeremy J. Hartmann --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 63 +++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 89 ++++++++++ ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl | 165 ++++++++++++++++++ 3 files changed, 317 insertions(+) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 9d88f98050e..f84dfee9d39 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -240,6 +240,27 @@ struct ggml_webgpu_ssm_conv_pipeline_key { } }; +/** CONV 2D */ +struct ggml_webgpu_conv2d_pipeline_key { + ggml_type weight_type; + ggml_type input_type; + ggml_type output_type; + + bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const { + return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type; + } +}; + +struct ggml_webgpu_conv2d_pipeline_key_hash { + size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.weight_type); + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + return seed; + } +}; + /** Gated Delta Net **/ struct ggml_webgpu_gated_delta_net_pipeline_key { int type; @@ -789,6 +810,8 @@ class ggml_webgpu_shader_lib { rope_pipelines; std::unordered_map soft_max_pipelines; + std::unordered_map + conv2d_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -2382,6 +2405,46 @@ class ggml_webgpu_shader_lib { return soft_max_pipelines[key]; } + webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_conv2d_pipeline_key key = {}; + key.weight_type = context.src0->type; + key.input_type = context.src1->type; + key.output_type = context.dst->type; + + auto it = conv2d_pipelines.find(key); + if (it != conv2d_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "conv_2d"; + + auto push_type_defines = [&](const char * prefix, ggml_type type) { + std::string s_prefix = prefix; + if (type == GGML_TYPE_F32) { + defines.push_back(s_prefix + "_F32"); + } else if (type == GGML_TYPE_F16) { + defines.push_back(s_prefix + "_F16"); + } else { + GGML_ABORT("Unsupported type for CONV_2D shader"); + } + }; + + push_type_defines("WEIGHT", key.weight_type); + push_type_defines("INPUT", key.input_type); + push_type_defines("OUTPUT", key.output_type); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_conv2d, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + conv2d_pipelines[key] = pipeline; + return conv2d_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index a2923145230..551586751c0 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,6 +8,7 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" +#include "ggml.h" #ifdef __EMSCRIPTEN__ # include @@ -921,6 +922,87 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } +static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t p0 = ggml_get_op_params_i32(dst, 2); + const int32_t p1 = ggml_get_op_params_i32(dst, 3); + const int32_t d0 = ggml_get_op_params_i32(dst, 4); + const int32_t d1 = ggml_get_op_params_i32(dst, 5); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], + + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + + (uint32_t) s0, + (uint32_t) s1, + (uint32_t) p0, + (uint32_t) p1, + (uint32_t) d0, + (uint32_t) d1, + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), + }; + + uint32_t max_wg_size = + std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX); + uint32_t wg_size = + std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size); + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = wg_size; + + webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + uint32_t n_out = ggml_nelements(dst); + uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size); + uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + uint32_t wg_x = std::min(total_wg, max_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -2477,6 +2559,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context c case GGML_OP_SUM: case GGML_OP_SUM_ROWS: return ggml_webgpu_sum_rows(ctx, src0, node); + case GGML_OP_CONV_2D: + return ggml_webgpu_conv_2d(ctx, src0, src1, node); default: return std::nullopt; } @@ -3495,6 +3579,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SOLVE_TRI: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; break; + case GGML_OP_CONV_2D: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + break; case GGML_OP_SSM_CONV: supports_op = op->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl new file mode 100644 index 00000000000..9eb131dc221 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl @@ -0,0 +1,165 @@ +#include "common_decls.tmpl" +enable f16; + +@group(0) @binding(0) +#if defined(WEIGHT_F32) +var weights: array; +#elif defined(WEIGHT_F16) +var weights: array; +#endif + +@group(0) @binding(1) +#if defined(INPUT_F32) +var input: array; +#elif defined(INPUT_F16) +var input: array; +#endif + +@group(0) @binding(2) +#if defined(OUTPUT_F32) +var output: array; +#elif defined(OUTPUT_F16) +var output: array; +#endif + +struct Params { + offset_w: u32, + offset_i: u32, + offset_o: u32, + + // element strides + sw0: u32, sw1: u32, sw2: u32, sw3: u32, + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + // kernel dimensions + KW: u32, KH: u32, IC: u32, + // input dimensions + IW: u32, IH: u32, + // output dimensions + OW: u32, OH: u32, OC_out: u32, N_out: u32, + + // stride + s0: u32, s1: u32, + // padding + p0: u32, p1: u32, + // dilation + d0: u32, d1: u32, +}; + +@group(0) @binding(3) +var params: Params; + +fn load_weight(idx: u32) -> f32 { + #if defined(WEIGHT_F32) + return weights[idx]; + #elif defined(WEIGHT_F16) + return f32(weights[idx]); + #endif +} + +fn load_input(idx: u32) -> f32 { + #if defined(INPUT_F32) + return input[idx]; + #elif defined(INPUT_F16) + return f32(input[idx]); + #endif +} + +fn store_output(idx: u32, val: f32) { + #if defined(OUTPUT_F32) + output[idx] = val; + #elif defined(OUTPUT_F16) + output[idx] = f16(val); + #endif +} + +fn ceil_div_u32(x: u32, y: u32) -> u32 { + return (x + y - 1) / y; +} + +// returns the first valid kernel index k such that base + k * step >= 0 +fn first_valid_k(base: i32, step: u32) -> u32 { + if (base >= 0) { + return 0; + } + + return ceil_div_u32(u32(-base), step); +} + +// returns the first invalid kernel index k such that base + k * step >= limit so valid k are in [0, end_valid_k) +fn end_valid_k(base: i32, step: u32, limit: u32, k_max: u32) -> u32 { + let remaining = i32(limit) - base; + if (remaining <= 0) { + return 0; + } + + return min(k_max, ceil_div_u32(u32(remaining), step)); +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let threads_per_group = u32(WG_SIZE); + let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y; + let n_out = params.OW * params.OH * params.OC_out * params.N_out; + + var sum: f32 = 0.0; + if (i_out >= n_out) { + return; + } + + // Kernel layout: [KW, KH, IC, ..] + // Input layout: [IW, IH, .., ..] + // Output layout: [OW, OH, OC, N] + + var i = i_out; + let n = i / (params.OC_out * params.OH * params.OW); + i = i % (params.OC_out * params.OH * params.OW); + let oc = i / (params.OH * params.OW); + i = i % (params.OH * params.OW); + let oh = i / params.OW; + let ow = i % params.OW; + + let ow_base = i32(ow * params.s0) - i32(params.p0); + let oh_base = i32(oh * params.s1) - i32(params.p1); + + // clip the valid kernel window once + let kw_begin = first_valid_k(ow_base, params.d0); + let kw_end = end_valid_k(ow_base, params.d0, params.IW, params.KW); + let kh_begin = first_valid_k(oh_base, params.d1); + let kh_end = end_valid_k(oh_base, params.d1, params.IH, params.KH); + + // entire receptive field is out of bounds + if (kw_begin >= kw_end || kh_begin >= kh_end) { + let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3; + store_output(out_idx, 0.0); + return; + } + + let weight_oc_base = params.offset_w + oc * params.sw3; + let input_n_base = params.offset_i + n * params.si3; + + for (var ic: u32 = 0; ic < params.IC; ic += 1) { + let w_base_ic = ic * params.sw2 + weight_oc_base; + let in_base = ic * params.si2 + input_n_base; + + for (var kh: u32 = kh_begin; kh < kh_end; kh += 1) { + let ih = u32(oh_base + i32(kh * params.d1)); + let w_row_base = w_base_ic + kh * params.sw1; + let in_row_base = in_base + ih * params.si1; + for (var kw: u32 = kw_begin; kw < kw_end; kw += 1) { + let iw = u32(ow_base + i32(kw * params.d0)); + let w_idx = w_row_base + kw * params.sw0; + let in_idx = in_row_base + iw * params.si0; + sum += load_weight(w_idx) * load_input(in_idx); + } + } + } + + let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3; + store_output(out_idx, sum); +} From c5bb7c0078d94cbf6f85caa5a7bc19cf310d846f Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Wed, 22 Apr 2026 18:02:56 +0530 Subject: [PATCH 175/249] sycl: Improve mul_mat_id memory efficiency and add BF16 fast path (llama/22119) * sycl: size mul_mat_id staging buffers by routed rows Previously src1_contiguous/dst_contiguous in ggml_sycl_mul_mat_id were sized to ggml_nelements(src1/dst), which over-allocates when ne12 > 1 and can fail with UR_RESULT_ERROR_OUT_OF_HOST_MEMORY on Level Zero for MoE models (notably with --cpu-moe). Size them by the actual number of routed rows (ids->ne[1] * n_ids) instead. * sycl: add bf16 mul_mat fast path via DNNL When src0 is BF16 (commonly the case for lm_head / output.weight), the existing f16 path is skipped because bf16 isn't covered, and the f32 fallback dequantizes the entire src0 slab to f32 in a single pool alloc (row_diff*ne00 floats). For large-vocab models this can reach several GB and fail with UR_RESULT_ERROR_OUT_OF_HOST_MEMORY on Level Zero. Add a bf16xbf16 -> f32 DNNL matmul fast path that uses the bf16 storage in place and only materializes a small src1 bf16 conversion buffer. bf16 matmul accumulates in f32, so it's correct even when the op requests GGML_PREC_F32 (as lm_head does). - gemm.hpp: map bfloat16 to dnnl::memory::data_type::bf16. - convert.{hpp,cpp}: expose ggml_get_to_bf16_sycl for f32/f16/bf16 -> bf16. - ggml-sycl.cpp: take the bf16 path early in ggml_sycl_op_mul_mat_sycl when DNNL and GGML_SYCL_HAS_BF16 are both available. --- ggml/src/ggml-sycl/common.hpp | 7 +++++++ ggml/src/ggml-sycl/convert.cpp | 23 ++++++++++++++++------- ggml/src/ggml-sycl/convert.hpp | 9 +++++++++ ggml/src/ggml-sycl/gemm.hpp | 3 +++ ggml/src/ggml-sycl/ggml-sycl.cpp | 30 ++++++++++++++++++++++++++++-- ggml/src/ggml-sycl/set_rows.cpp | 8 +++++++- 6 files changed, 70 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index fd84c917853..0101b27640a 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -28,6 +28,13 @@ namespace syclexp = sycl::ext::oneapi::experimental; +#if defined(__INTEL_LLVM_COMPILER) && __has_include() + #include + #ifndef GGML_SYCL_HAS_BF16 + #define GGML_SYCL_HAS_BF16 + #endif +#endif + #if GGML_SYCL_DNNL #include "dnnl.hpp" #include "dnnl_sycl.hpp" diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index f3c521b45f6..67b9c06f3e4 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -2,13 +2,6 @@ #include "dequantize.hpp" #include "presets.hpp" -#if defined(__INTEL_LLVM_COMPILER) - #if __has_include() - #include - #define GGML_SYCL_HAS_BF16 - #endif -#endif - template static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, const sycl::nd_item<3> &item_ct1) { @@ -767,6 +760,22 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { } +#ifdef GGML_SYCL_HAS_BF16 +to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * /*dst*/) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_sycl; + case GGML_TYPE_F16: + return convert_unary_sycl; + case GGML_TYPE_BF16: + return convert_unary_sycl; + default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); + return nullptr; + } +} +#endif + to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) { switch (type) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index 6e621f2154d..8de79d10ff6 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -23,6 +23,11 @@ typedef to_t_sycl_t to_fp16_sycl_t; to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst); to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst); +#ifdef GGML_SYCL_HAS_BF16 +typedef to_t_sycl_t to_bf16_sycl_t; +to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * dst); +#endif + // Nc = Non-contiguous template using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, @@ -35,15 +40,19 @@ template inline dst_t ggml_sycl_cast(src_t x) { if constexpr (std::is_same_v) { return x; +#ifdef GGML_SYCL_HAS_BF16 } else if constexpr (std::is_same_v) { return sycl::ext::oneapi::bfloat16(float(x)); } else if constexpr (std::is_same_v) { return static_cast(x); +#endif } else if constexpr (std::is_same_v && std::is_same_v) { return x.template convert(); +#ifdef GGML_SYCL_HAS_BF16 } else if constexpr (std::is_same_v && std::is_same_v>) { return {x.x, x.y}; +#endif } else if constexpr(std::is_same_v) { return int32_t(x); } else { diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index dcf6c7aeeb4..c202da110be 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -29,6 +29,9 @@ class DnnlGemmWrapper { static constexpr dt to_dt() { if constexpr (std::is_same_v) return dt::f32; else if constexpr (std::is_same_v) return dt::f16; +#ifdef GGML_SYCL_HAS_BF16 + else if constexpr (std::is_same_v) return dt::bf16; +#endif else static_assert(0); } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index c02a41ad862..3829da87903 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2176,6 +2176,31 @@ inline void ggml_sycl_op_mul_mat_sycl( #else bool use_fp16 = false; #endif + +#if GGML_SYCL_DNNL && defined(GGML_SYCL_HAS_BF16) + // Fast path for bf16 src0 + if (src0->type == GGML_TYPE_BF16 && !g_ggml_sycl_disable_dnn && ggml_is_contiguous(src0) && + row_diff == src0->ne[1]) { + using bf16_t = sycl::ext::oneapi::bfloat16; + ggml_sycl_pool_alloc src1_as_bf16(ctx.pool(), src1_ncols*ne10); + if (src1->type != GGML_TYPE_BF16) { + const to_bf16_sycl_t to_bf16_sycl = ggml_get_to_bf16_sycl(src1->type, dst); + GGML_ASSERT(to_bf16_sycl != nullptr); + to_bf16_sycl(src1_ddf_i, src1_as_bf16.get(), src1_ncols*ne10, stream); + } else { + stream->memcpy(src1_as_bf16.get(), src1_ddf_i, src1_ncols*ne10*sizeof(bf16_t)); + } + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, + src0_dd_i, DnnlGemmWrapper::to_dt(), + src1_as_bf16.get(), DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_padded_row_size); + return; + } +#endif + if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { ggml_sycl_pool_alloc src0_as_f16(ctx.pool()); @@ -3848,8 +3873,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, } } } else { - ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); - ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + const int64_t n_routed_rows = ids->ne[1] * n_ids; + ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne10); + ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne0); src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); diff --git a/ggml/src/ggml-sycl/set_rows.cpp b/ggml/src/ggml-sycl/set_rows.cpp index a641c100913..8fb41943525 100644 --- a/ggml/src/ggml-sycl/set_rows.cpp +++ b/ggml/src/ggml-sycl/set_rows.cpp @@ -4,7 +4,11 @@ namespace utils { template static constexpr bool is_arithmetic_v() { - return std::is_arithmetic_v || std::is_same_v || std::is_same_v; + return std::is_arithmetic_v || std::is_same_v +#ifdef GGML_SYCL_HAS_BF16 + || std::is_same_v +#endif + ; } } @@ -181,6 +185,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s stream ); break; +#ifdef GGML_SYCL_HAS_BF16 case GGML_TYPE_BF16: set_rows_sycl( src0_d, src1_d, (char *)dst->data, @@ -193,6 +198,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s stream ); break; +#endif case GGML_TYPE_Q8_0: set_rows_sycl_q(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream); break; From 0fbe4c4ca7a4c42c217a8f2ada04677c33a20b2c Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Thu, 23 Apr 2026 02:51:40 +0900 Subject: [PATCH 176/249] ggml-webgpu: Add fused RMS_NORM + MUL (llama/21983) * fused rms_norm_mul + mul * Add GGML_WEBGPU_DISABLE_FUSION for being able to disable kernel fusion. * Decouple num_fused_ops from webgpu_context; misc cleanup * Fix eps handling and remove disable_fusion. * Fix not to use c++20 initializers. --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 71 +++++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 157 ++++++++++++++++-- .../wgsl-shaders/rms_norm_mul.wgsl | 139 ++++++++++++++++ 3 files changed, 349 insertions(+), 18 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index f84dfee9d39..6593a9fe16b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -194,6 +194,26 @@ struct ggml_webgpu_row_norm_pipeline_key_hash { } }; +/** RMS_NORM + MUL **/ + +struct ggml_webgpu_rms_norm_mul_pipeline_key { + bool inplace; + bool src_overlap; + + bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const { + return inplace == other.inplace && src_overlap == other.src_overlap; + } +}; + +struct ggml_webgpu_rms_norm_mul_pipeline_key_hash { + size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.src_overlap); + return seed; + } +}; + /** Pad **/ struct ggml_webgpu_pad_pipeline_key { bool circular; @@ -517,7 +537,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ const size_t q_tile = context.sg_mat_m; const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; + size_t bytes_per_kv = 0; if (!key.kv_direct) { bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); } @@ -755,16 +775,17 @@ class ggml_webgpu_shader_lib { std::unordered_map cumsum_pipelines; // key is fixed, no variants yet std::unordered_map row_norm_pipelines; // op/inplace + std::unordered_map - get_rows_pipelines; // src_type, vectorized + get_rows_pipelines; // src_type, vectorized std::unordered_map - unary_pipelines; // type/op/inplace + unary_pipelines; // type/op/inplace std::unordered_map - scale_pipelines; // inplace + scale_pipelines; // inplace std::unordered_map - solve_tri_pipelines; // type + solve_tri_pipelines; // type std::unordered_map - ssm_conv_pipelines; // type/vectorized + ssm_conv_pipelines; // type/vectorized std::unordered_map @@ -813,6 +834,11 @@ class ggml_webgpu_shader_lib { std::unordered_map conv2d_pipelines; + std::unordered_map + rms_norm_mul_pipelines; + public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -1828,6 +1854,39 @@ class ggml_webgpu_shader_lib { return unary_pipelines[key]; } + webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_rms_norm_mul_pipeline_key key = {}; + key.inplace = context.inplace; + key.src_overlap = context.src_overlap; + + auto it = rms_norm_mul_pipelines.find(key); + if (it != rms_norm_mul_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string op_name = "RMS_NORM_MUL"; + std::string variant = op_name; + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + rms_norm_mul_pipelines[key] = pipeline; + return rms_norm_mul_pipelines[key]; + } + webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_binary_pipeline_key key = {}; key.type = context.dst->type; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 551586751c0..5d3169904c5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1972,6 +1972,94 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static std::optional ggml_webgpu_rms_norm_mul(webgpu_context & ctx, + ggml_tensor * rn_src, + ggml_tensor * rn_dst, + ggml_tensor * mul_src0, + ggml_tensor * mul_src1, + ggml_tensor * dst) { + ggml_tensor * mul_src; + + if (ggml_webgpu_tensor_equal(rn_dst, mul_src0)) { + mul_src = mul_src1; + } else if (ggml_webgpu_tensor_equal(rn_dst, mul_src1)) { + mul_src = mul_src0; + } else { + GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); + } + + bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || + (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); + bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); + + uint32_t offset_merged_rn_src = 0; + uint32_t offset_merged_mul_src = 0; + size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, rn_src); + size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src); + + if (src_overlap) { + size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset); + offset_merged_rn_src = + (uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type)); + offset_merged_mul_src = + (uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type)); + } + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)), + offset_merged_rn_src, + offset_merged_mul_src, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)), + (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)), + (uint32_t) (rn_src->nb[3] / ggml_type_size(rn_src->type)), + (uint32_t) (mul_src->nb[1] / ggml_type_size(mul_src->type)), + (uint32_t) (mul_src->nb[2] / ggml_type_size(mul_src->type)), + (uint32_t) (mul_src->nb[3] / ggml_type_size(mul_src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) mul_src->ne[0], + (uint32_t) mul_src->ne[1], + (uint32_t) mul_src->ne[2], + (uint32_t) mul_src->ne[3], + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(rn_dst, 0)) // epsilon, treated as f32 in the shader + }; + + std::vector entries; + + if (inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); + } else if (src_overlap) { + size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset); + size_t merged_end = + std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src), + mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src)); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, + merged_end - merged_offset)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); + } + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; + shader_lib_ctx.src_overlap = src_overlap; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); +} + static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); @@ -2468,15 +2556,48 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, int node_idx) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + return false; + } + + // additional constraints specific to this fusion + const ggml_tensor * rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor * mul = cgraph->nodes[node_idx + 1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || mul->src[1]->type != GGML_TYPE_F32 || mul->type != GGML_TYPE_F32) { + return false; + } + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) { + return false; + } + // rms_norm shader assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + + return true; +} + // Returns the encoded command, or std::nullopt if the operation is a no-op -static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { +static std::optional ggml_webgpu_encode(webgpu_context ctx, + ggml_cgraph * cgraph, + int node_idx, + int & num_encoded_ops) { + ggml_tensor ** nodes = cgraph->nodes; + ggml_tensor * node = nodes[node_idx]; + if (ggml_is_empty(node)) { return std::nullopt; } if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { return std::nullopt; } - WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); + WEBGPU_LOG_DEBUG("ggml_webgpu_encode(" << node << ", " << ggml_op_name(node->op) << ")"); ggml_tensor * src0 = node->src[0]; ggml_tensor * src1 = node->src[1]; @@ -2519,6 +2640,13 @@ static std::optional ggml_webgpu_encode_node(webgpu_context c case GGML_OP_REPEAT: return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: + if (ggml_webgpu_can_fuse_rms_norm_mul(cgraph, node_idx)) { + num_encoded_ops = 2; + ggml_tensor * mul_node = nodes[node_idx + 1]; + return ggml_webgpu_rms_norm_mul(ctx, src0, node, mul_node->src[0], mul_node->src[1], mul_node); + } else { + return ggml_webgpu_row_norm(ctx, src0, node); + } case GGML_OP_L2_NORM: return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -2629,6 +2757,8 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str uint32_t num_inflight_batches = 0; bool contains_set_rows = false; bool batch_compute_passes = true; + int num_encoded_ops = 1; + int node_idx = 0; #ifdef GGML_WEBGPU_GPU_PROFILE ctx->profile_timestamp_query_count = 0; @@ -2641,11 +2771,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); } - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { + while (node_idx < cgraph->n_nodes) { + if (cgraph->nodes[node_idx]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } - if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + if (auto cmd = ggml_webgpu_encode(ctx, cgraph, node_idx, num_encoded_ops)) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; #ifdef GGML_WEBGPU_GPU_PROFILE @@ -2670,6 +2800,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->param_arena.reset(); commands.clear(); } + + node_idx += num_encoded_ops; + num_encoded_ops = 1; } if (ctx->active_compute_pass) { @@ -3237,7 +3370,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; - webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); + webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_arena.init( webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN, @@ -3487,12 +3620,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && - (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && - (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + const bool kv_direct = src1->type == GGML_TYPE_F16 && + (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && + (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct); if (min_bytes > limit_bytes) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl new file mode 100644 index 00000000000..71f063b51aa --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl @@ -0,0 +1,139 @@ +#ifdef INPLACE + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + mul_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#elif SRC_OVERLAP + +@group(0) @binding(0) +var merged_src: array; + +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + dst[dst_offset] = scale * merged_src[rn_src_offset] * merged_src[mul_src_offset]; +} + +#else + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + dst[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#endif + +struct Params { + offset_rn_src: u32, + offset_mul_src: u32, + offset_merged_rn_src: u32, + offset_merged_mul_src: u32, + offset_dst: u32, + + stride_rn_src1: u32, + stride_rn_src2: u32, + stride_rn_src3: u32, + + stride_mul_src1: u32, + stride_mul_src2: u32, + stride_mul_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + mul_src_ne0: u32, + mul_src_ne1: u32, + mul_src_ne2: u32, + mul_src_ne3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + eps: f32 +}; + +var scratch: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + + // one thread per row + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1; + let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; + + var sum = 0.0f; + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } +#ifdef SRC_OVERLAP + sum += pow(merged_src[i_rn_src_row + col], 2.0); +#else + sum += pow(rn_src[i_rn_src_row + col], 2.0); +#endif + col += WG_SIZE; + } + + scratch[lid.x] = sum; + + workgroupBarrier(); + + var offset: u32 = WG_SIZE / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + sum = scratch[0]; + + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); + + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + update(i_rn_src_row + col, i_dst_row + col, scale, i_mul_src_row + col % params.mul_src_ne0); + col += WG_SIZE; + } +} From d2a26dc8e26edc72f0ba1b9d9f727d34625c9c7b Mon Sep 17 00:00:00 2001 From: Nikhil Jain Date: Wed, 22 Apr 2026 10:52:01 -0700 Subject: [PATCH 177/249] Implement async tensor api and event api (llama/22099) * Only run webgpu CI on my fork * Implement set_tensor_async * Implement synchronize api * Implement event creation and deletion API * Cleanup * Cleanup * Comment out jobs for local CI run * Add webgpu only workflow * Delete .github/workflows/build-webgpu.yml * Cleanup * Cleanup * Update API with function handlers * Run clang-format * Replace one-shot buffer with a direct queue.WriteBuffer using the buffer context --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 99 ++++++++++++++++++++++++++-- 1 file changed, 92 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5d3169904c5..44e3bf82216 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2832,22 +2832,107 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str return GGML_STATUS_SUCCESS; } +struct ggml_backend_webgpu_event_context { + webgpu_global_context global_ctx; + wgpu::Future future; + bool recorded = false; +}; + +static ggml_backend_event_t ggml_backend_webgpu_device_event_new(ggml_backend_dev_t device) { + ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) device->context; + + auto * event_ctx = new ggml_backend_webgpu_event_context(); + event_ctx->global_ctx = dev_ctx->webgpu_global_ctx; + + auto * event = new ggml_backend_event; + event->device = device; + event->context = event_ctx; + return event; +} + +static void ggml_backend_webgpu_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + delete static_cast(event->context); + delete event; +} + +static void ggml_backend_webgpu_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context; + if (!event_ctx->recorded) { + return; + } + wgpu::WaitStatus status = + event_ctx->global_ctx->instance.WaitAny(event_ctx->future, WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + if (status == wgpu::WaitStatus::TimedOut) { + GGML_ABORT("ggml_webgpu: event_synchronize timed out after %u ms\n", WEBGPU_RUNTIME_WAIT_TIMEOUT_MS); + } + event_ctx->recorded = false; +} + +static void ggml_backend_webgpu_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; + ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context; + + event_ctx->future = backend_ctx->webgpu_ctx->global_ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, [](wgpu::QueueWorkDoneStatus, wgpu::StringView) {}); + event_ctx->recorded = true; +} + +static void ggml_backend_webgpu_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + GGML_UNUSED(backend); + ggml_backend_webgpu_device_event_synchronize(nullptr, event); +} + +static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + GGML_UNUSED(backend); + auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + + // Write aligned portion + buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); + + if (size % 4 != 0) { + // If size is not a multiple of 4, we need to memset the remaining bytes + size_t remaining_size = size % 4; + + // pack the remaining bytes into a uint32_t + uint32_t val32 = 0; + + for (size_t i = 0; i < remaining_size; i++) { + ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; + } + // memset the remaining bytes + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, + total_offset + (size - remaining_size), remaining_size); + } +} + +static void ggml_backend_webgpu_synchronize(ggml_backend_t backend) { + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; + ggml_backend_webgpu_wait_queue(backend_ctx->webgpu_ctx->global_ctx); +} + static ggml_backend_i ggml_backend_webgpu_i = { /* .get_name = */ ggml_backend_webgpu_name, /* .free = */ ggml_backend_webgpu_free, - /* .set_tensor_async = */ NULL, + /* .set_tensor_async = */ ggml_backend_webgpu_set_tensor_async, /* .get_tensor_async = */ NULL, /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, + /* .synchronize = */ ggml_backend_webgpu_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_webgpu_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, + /* .event_record = */ ggml_backend_webgpu_event_record, + /* .event_wait = */ ggml_backend_webgpu_event_wait, /* .graph_optimize = */ NULL, }; @@ -3810,9 +3895,9 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = { /* .supports_op = */ ggml_backend_webgpu_device_supports_op, /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft, /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, + /* .event_new = */ ggml_backend_webgpu_device_event_new, + /* .event_free = */ ggml_backend_webgpu_device_event_free, + /* .event_synchronize = */ ggml_backend_webgpu_device_event_synchronize, }; /* End GGML Backend Device Interface */ From 393fdffe20e5bdd9c0803220c75d53eeca90664f Mon Sep 17 00:00:00 2001 From: uvos Date: Thu, 23 Apr 2026 02:34:31 +0200 Subject: [PATCH 178/249] HIP: flip GGML_HIP_GRAPHS to default on (llama/22254) In #11362 hip graph was disabled by default as, at the time, its performance impact was negative. Due to improvements in rocm and our usage and construction of graphs this is no longer true, so lets change the default. --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 2effd587b41..b9f7deb150d 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -213,7 +213,7 @@ set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size") option(GGML_HIP "ggml: use HIP" OFF) -option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_GRAPHS "ggml: use HIP graph" ON) option(GGML_HIP_RCCL "ggml: use ROCm Collective Comm. Library" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) From b6b547885cd431e457d58bad58eb5e9ba972919b Mon Sep 17 00:00:00 2001 From: Anav Prasad Date: Thu, 23 Apr 2026 02:28:56 +0000 Subject: [PATCH 179/249] CUDA: fuse relu + sqr (llama/22249) --- ggml/src/ggml-cuda/ggml-cuda.cu | 30 ++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/unary.cu | 23 +++++++++++++++++++++++ ggml/src/ggml-cuda/unary.cuh | 2 ++ 3 files changed, 55 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 185956317e0..1c2c3b4ac69 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3592,6 +3592,30 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return true; } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_SQR + && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_RELU) { + const ggml_tensor * unary = cgraph->nodes[node_idx]; + const ggml_tensor * sqr = cgraph->nodes[node_idx+1]; + + if (ggml_get_unary_op(unary) != GGML_UNARY_OP_RELU) { + return false; + } + + if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) { + return false; + } + + if (unary->type != sqr->type) { + return false; + } + + if (!ggml_is_contiguous(unary->src[0])) { + return false; + } + + return true; + } + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) { const ggml_tensor *scale = cgraph->nodes[node_idx]; @@ -4100,6 +4124,12 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) { + ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { i += 2; ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 4ad30fa1f35..2aeba26f414 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -65,6 +65,11 @@ static __device__ __forceinline__ float op_sqr(float x) { return x * x; } +static __device__ __forceinline__ float op_relu_sqr(float x) { + const float r = fmaxf(x, 0.0f); + return r * r; +} + static __device__ __forceinline__ float op_sqrt(float x) { return sqrtf(x); } @@ -615,3 +620,21 @@ void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary GGML_ABORT("Unsupported unary op for fused unary+mul"); } } + +/* fused relu + sqr */ + +void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node) { + const ggml_tensor * src = relu_node->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); + GGML_ASSERT(src->type == sqr_node->type); + + const int k = ggml_nelements(src); + if (src->type == GGML_TYPE_F16) { + unary_cuda((const half *)src->data, (half *)sqr_node->data, k, stream); + } else { + unary_cuda((const float *)src->data, (float *)sqr_node->data, k, stream); + } +} diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index f1dd2183a6c..81ed873ecc3 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -91,6 +91,8 @@ void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node); +void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node); + __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) { return x / (1.0f + expf(-x)); } From df528c4f71cec95e0d024a512f4d36928e4e61e6 Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Wed, 22 Apr 2026 23:17:41 -0400 Subject: [PATCH 180/249] ggml-webgpu: add support for im2col (llama/22259) * shader(im2col): implement the im2col shader * shader(im2col): clean the formatting issues * shader(im2col): clean the editorconfig checker warning * fix(shader): address the workgroup issues of im2col and conv2d --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 59 ++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 127 +++++++++++++++--- ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl | 101 ++++++++++++++ 3 files changed, 268 insertions(+), 19 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 6593a9fe16b..efc5b8c97a7 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -281,6 +281,25 @@ struct ggml_webgpu_conv2d_pipeline_key_hash { } }; +/** Im2Col **/ +struct ggml_webgpu_im2col_pipeline_key { + ggml_type input_type; + ggml_type output_type; + + bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const { + return input_type == other.input_type && output_type == other.output_type; + } +}; + +struct ggml_webgpu_im2col_pipeline_key_hash { + size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + return seed; + } +}; + /** Gated Delta Net **/ struct ggml_webgpu_gated_delta_net_pipeline_key { int type; @@ -833,6 +852,8 @@ class ggml_webgpu_shader_lib { soft_max_pipelines; std::unordered_map conv2d_pipelines; + std::unordered_map + im2col_pipelines; std::unordered_maptype; + key.output_type = context.dst->type; + + auto it = im2col_pipelines.find(key); + if (it != im2col_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "im2col"; + + auto push_type_defines = [&](const char * prefix, ggml_type type) { + std::string s_prefix = prefix; + if (type == GGML_TYPE_F32) { + defines.push_back(s_prefix + "_F32"); + } else if (type == GGML_TYPE_F16) { + defines.push_back(s_prefix + "_F16"); + } else { + GGML_ABORT("Unsupported type for IM2COL shader"); + } + }; + + push_type_defines("INPUT", key.input_type); + push_type_defines("OUTPUT", key.output_type); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_im2col, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + im2col_pipelines[key] = pipeline; + return im2col_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 44e3bf82216..bcca2bd4627 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -979,25 +979,108 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx, ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; - uint32_t max_wg_size = - std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX); - uint32_t wg_size = - std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; shader_lib_ctx.src1 = src1; shader_lib_ctx.dst = dst; - shader_lib_ctx.max_wg_size = wg_size; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); - uint32_t n_out = ggml_nelements(dst); - uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size); - uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - uint32_t wg_x = std::min(total_wg, max_wg); + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t p0 = ggml_get_op_params_i32(dst, 2); + const int32_t p1 = ggml_get_op_params_i32(dst, 3); + const int32_t d0 = ggml_get_op_params_i32(dst, 4); + const int32_t d1 = ggml_get_op_params_i32(dst, 5); + const bool is_2D = ggml_get_op_params_i32(dst, 6) == 1; + + const uint32_t KW = src0->ne[0]; + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t IC = is_2D ? src0->ne[2] : src0->ne[1]; + + const uint32_t IW = src1->ne[0]; + const uint32_t IH = is_2D ? src1->ne[1] : 1; + const uint32_t N = is_2D ? src1->ne[3] : src1->ne[2]; + + const uint32_t OW = dst->ne[1]; + const uint32_t OH = is_2D ? dst->ne[2] : 1; + + const uint32_t si0 = (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)); + const uint32_t si1 = is_2D ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0; + const uint32_t si2 = is_2D ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)); + const uint32_t si3 = is_2D ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)); + + const uint32_t so0 = (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)); + const uint32_t so1 = (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)); + const uint32_t so2 = is_2D ? (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)) : 0; + const uint32_t so3 = is_2D ? (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)) : + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + si0, + si1, + si2, + si3, + so0, + so1, + so2, + so3, + + KW, + KH, + IC, + + IW, + IH, + N, + + OW, + OH, + + (uint32_t) s0, + (uint32_t) s1, + (uint32_t) p0, + (uint32_t) p1, + (uint32_t) d0, + (uint32_t) d1, + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_im2col_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); uint32_t wg_y = CEIL_DIV(total_wg, wg_x); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); @@ -1988,8 +2071,8 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); } - bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || - (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); + bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || + (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); uint32_t offset_merged_rn_src = 0; @@ -2689,6 +2772,8 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, return ggml_webgpu_sum_rows(ctx, src0, node); case GGML_OP_CONV_2D: return ggml_webgpu_conv_2d(ctx, src0, src1, node); + case GGML_OP_IM2COL: + return ggml_webgpu_im2col(ctx, src0, src1, node); default: return std::nullopt; } @@ -3455,7 +3540,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; - webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); + webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_arena.init( webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN, @@ -3705,12 +3790,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && - (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && - (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + const bool kv_direct = src1->type == GGML_TYPE_F16 && + (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && + (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct); if (min_bytes > limit_bytes) { @@ -3802,6 +3887,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); break; + case GGML_OP_IM2COL: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; case GGML_OP_SSM_CONV: supports_op = op->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl new file mode 100644 index 00000000000..386ebab879f --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl @@ -0,0 +1,101 @@ +#include "common_decls.tmpl" +enable f16; + +@group(0) @binding(0) +#if defined(INPUT_F32) +var input: array; +#elif defined(INPUT_F16) +var input: array; +#endif + +@group(0) @binding(1) +#if defined(OUTPUT_F32) +var output: array; +#elif defined(OUTPUT_F16) +var output: array; +#endif + +struct Params { + offset_i: u32, + offset_o: u32, + + // element strides + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + KW: u32, KH: u32, IC: u32, + IW: u32, IH: u32, N: u32, + OW: u32, OH: u32, + + // stride + s0: u32, s1: u32, + // padding + p0: u32, p1: u32, + // dilation + d0: u32, d1: u32, +} + +@group(0) @binding(2) +var params: Params; + +fn load_input(idx: u32) -> f32 { + #if defined(INPUT_F32) + return input[idx]; + #elif defined(INPUT_F16) + return f32(input[idx]); + #endif +} + +fn store_output(idx: u32, val: f32) { + #if defined(OUTPUT_F32) + output[idx] = val; + #elif defined(OUTPUT_F16) + output[idx] = f16(val); + #endif +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let threads_per_group = u32(WG_SIZE); + let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y; + let K = params.KW * params.KH * params.IC; + let M = params.OW * params.OH; + let total = K * M * params.N; + + if (i_out >= total) { + return; + } + + // decode (k, m, n) + var i = i_out; + let n = i / (K * M); + i = i % (K * M); + let m = i / K; + let k = i % K; + + // decode (oh, ow) + let oh = m / params.OW; + let ow = m % params.OW; + + // decode (kw, kh, ic) + let kw = k % params.KW; + let tmp = k / params.KW; + let kh = tmp % params.KH; + let ic = tmp / params.KH; + + let iw_i32 = i32(ow * params.s0 + kw * params.d0) - i32(params.p0); + let ih_i32 = i32(oh * params.s1 + kh * params.d1) - i32(params.p1); + + if (iw_i32 >= 0 && iw_i32 < i32(params.IW) && ih_i32 >= 0 && ih_i32 < i32(params.IH)) { + let iw = u32(iw_i32); + let ih = u32(ih_i32); + let in_idx = params.offset_i + iw * params.si0 + ih * params.si1 + ic * params.si2 + n * params.si3; + store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, load_input(in_idx)); + } else { + store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, 0.0); + } +} From b938c5026c42d5a1c3a11665fe714cc1768d0823 Mon Sep 17 00:00:00 2001 From: abotsis Date: Wed, 22 Apr 2026 23:18:56 -0600 Subject: [PATCH 181/249] sycl : fused MoE mul_mat_vec_q for TG (llama/21920) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * sycl : fused MoE mul_mat_vec_q for TG Create an MMVQ kernel so ggml_sycl_mul_mat_id can consolidate n_experts_used matmuls in a single kernel launch. The kernel also reads expert IDs directly, removing a per-call host sync. This is similar to the CUDA backend's ggml_cuda_mul_mat_vec_q* paths. All types supported in the current MMVQ are supported here as well: Q2_K, Q3_K, Q4_K, Q5_K, Q6_K, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0 It will fall back to the existing per-expert path when src0 has been rewritten by opt_for_reorder(), and for any shape the fused path doesn't handle. test-backend-ops passes for supported type/shape combos. Benchmark: Qwen3-Next-35B-A3B Q4_K_M on Intel Arc B70 (SYCL0), baseline 707c0b7a6, 16k context, -fa 0. build/bin/llama-bench -hf unsloth/Qwen3.5-35B-A3B-GGUF:Q4_K_M \ -p 1024 -n 128 -d 16384 -ngl 99 -fa 0 -ub 2048 -r 2 -dev SYCL0 Before (3 runs on 707c0b7a6): | test | run 1 | run 2 | run 3 | | --------------- | ----------------:| ----------------:| ----------------:| | pp1024 @ d16384 | 533.26 ± 4.87 | 535.20 ± 2.78 | 524.27 ± 3.10 | | tg128 @ d16384 | 33.47 ± 0.02 | 33.31 ± 0.02 | 33.17 ± 0.05 | After (3 runs on 707c0b7a6 + this patch): | test | run 1 | run 2 | run 3 | | --------------- | ----------------:| ----------------:| ----------------:| | pp1024 @ d16384 | 534.06 ± 0.97 | 531.95 ± 0.02 | 520.94 ± 20.10 | | tg128 @ d16384 | 45.85 ± 0.21 | 45.95 ± 0.45 | 46.22 ± 0.12 | disclosure: Claude wrote it, but I reviewed and understand the implementation (albeit my C is a little rusty). * sycl: also support nvfp4 and mxfp4 expert types * sycl: terser comments/nested dispatch in response to review * sycl: more comment cleanup in mmvq.cpp/hpp --------- Co-authored-by: Debian --- ggml/src/ggml-sycl/ggml-sycl.cpp | 51 +++++++++++ ggml/src/ggml-sycl/mmvq.cpp | 151 +++++++++++++++++++++++++++++++ ggml/src/ggml-sycl/mmvq.hpp | 16 ++++ 3 files changed, 218 insertions(+) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3829da87903..36923160d72 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3808,6 +3808,51 @@ __dpct_inline__ static void k_copy_dst_from_contiguous( } } +// Fused MoE TG fast path. Returns false to fall back to the per-expert loop below. +static bool ggml_sycl_mul_mat_id_mmvq_fused( + ggml_backend_sycl_context & ctx, const ggml_tensor * src0, + const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) +{ + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + if (ne12 != 1) return false; + if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) return false; + if (ne10 != src0->ne[0] || ne10 % QK8_1 != 0) return false; + if (!ggml_is_contiguous(src1)) return false; + + // Reorder layout not supported; fall back. + const ggml_tensor_extra_gpu * src0_extra = + static_cast(src0->extra); + if (src0_extra && src0_extra->optimized_feature.reorder) return false; + + const int64_t n_ids_per_group = ids->ne[0]; + if (ids->ne[1] != 1) return false; + if (ne11 != 1 && ne11 != n_ids_per_group) return false; + + const queue_ptr stream = ctx.stream(); + const int src1_padded_cols = GGML_PAD((int) ne10, MATRIX_ROW_PADDING); + const int n_experts_used = (int) n_ids_per_group; + const int nrows = (int) src0->ne[1]; + + ggml_sycl_pool_alloc src1_q8_alloc(ctx.pool(), + (size_t) ne11 * src1_padded_cols * sizeof(block_q8_1) / QK8_1); + char * src1_ddq = src1_q8_alloc.get(); + quantize_row_q8_1_sycl( + (const float *) src1->data, src1_ddq, (int) ne10, (int) ne11, + src1_padded_cols, stream); + + const size_t bytes_per_qrow = (size_t) src1_padded_cols * sizeof(block_q8_1) / QK8_1; + const size_t src1_row_stride = (ne11 == 1) ? 0 : bytes_per_qrow; + + return ggml_sycl_mul_mat_vec_q_id( + src0->type, src0->data, src1_ddq, (const int32_t *) ids->data, + (float *) dst->data, (int) ne10, nrows, n_experts_used, + /*expert_weight_stride=*/ src0->nb[2], + /*dst_row_stride=*/ dst->nb[1], + src1_row_stride, stream); +} + static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); @@ -3823,6 +3868,12 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const int64_t n_as = ne02; const int64_t n_ids = ids->ne[0]; + if (ne12 == 1) { + if (ggml_sycl_mul_mat_id_mmvq_fused(ctx, src0, src1, ids, dst)) { + return; + } + } + std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 3a4577ecbbc..8fa2198f35a 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1199,3 +1199,154 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens GGML_UNUSED(src1_ddf_i); GGML_UNUSED(ctx); } + +// src1_row_stride: 0 for shared src1 (gate/up proj), else per-expert stride (down proj). +template +static void mul_mat_vec_q_moe( + const void * __restrict__ vx_base, const void * __restrict__ vy_base, + float * __restrict__ dst_base, const int32_t * __restrict__ ids_dev, + const int ncols, const int nrows, + const size_t expert_weight_stride, const size_t dst_row_stride, + const size_t src1_row_stride, + const sycl::nd_item<3> & item_ct1) { + + const int expert_idx = item_ct1.get_group(1); + const int i02 = ids_dev[expert_idx]; + + const char * vx = (const char *) vx_base + (size_t) i02 * expert_weight_stride; + const char * vy = (const char *) vy_base + (size_t) expert_idx * src1_row_stride; + float * dst = (float *) ((char *) dst_base + (size_t) expert_idx * dst_row_stride); + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; + + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; + const int iby = i * (qk / QK8_1); + + for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) { + const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr)); + tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void launch_mul_mat_vec_q_moe( + const void * vx_base, const void * vy, const int32_t * ids_dev, + float * dst_base, const int ncols, const int nrows, const int n_experts_used, + const size_t expert_weight_stride, const size_t dst_row_stride, + const size_t src1_row_stride, + dpct::queue_ptr stream) { + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, (unsigned) n_experts_used, (unsigned) block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_moe( + vx_base, vy, dst_base, ids_dev, ncols, nrows, + expert_weight_stride, dst_row_stride, src1_row_stride, item); + }); + }); +} + +bool ggml_sycl_mul_mat_vec_q_id( + enum ggml_type src0_type, + const void * vx_base, + const void * vy, + const int32_t * ids_dev, + float * dst_base, + int ncols, + int nrows, + int n_experts_used, + size_t expert_weight_stride, + size_t dst_row_stride, + size_t src1_row_stride, + dpct::queue_ptr stream) { + switch (src0_type) { + case GGML_TYPE_Q4_0: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q4_1: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_0: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_1: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q8_0: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q2_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q3_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q4_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q6_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_MXFP4: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_NVFP4: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + default: + return false; + } +} diff --git a/ggml/src/ggml-sycl/mmvq.hpp b/ggml/src/ggml-sycl/mmvq.hpp index 049b43d4535..d674dc1d61e 100644 --- a/ggml/src/ggml-sycl/mmvq.hpp +++ b/ggml/src/ggml-sycl/mmvq.hpp @@ -24,4 +24,20 @@ void ggml_sycl_op_mul_mat_vec_q( const int64_t src1_ncols, const int64_t src1_padded_row_size, const dpct::queue_ptr &stream); +// Requires standard (non-reorder) block layout for src0. +// Returns false if src0_type isn't handled; caller should fall back. +bool ggml_sycl_mul_mat_vec_q_id( + enum ggml_type src0_type, + const void * vx_base, // start of stacked expert weights + const void * vy, // pre-quantized src1 (Q8_1) + const int32_t * ids_dev, // device-side int32, length n_experts_used + float * dst_base, + int ncols, + int nrows, + int n_experts_used, + size_t expert_weight_stride, // bytes between experts in vx_base + size_t dst_row_stride, // bytes between dst rows + size_t src1_row_stride, // 0 = shared src1, else per-expert stride in bytes + dpct::queue_ptr stream); + #endif // GGML_SYCL_MMVQ_HPP From 1aba06173778618cc3d56ce6201eb703935c5e99 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 23 Apr 2026 08:22:08 +0300 Subject: [PATCH 182/249] ggml-base: use MATH_LIBRARY variable instead of hardcoded 'm' (llama/22239) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #22237 — the find_library(MATH_LIBRARY m) result was being discarded and the target linked against the literal 'm' string. This prevents users from overriding the math library (e.g. for AMD AOCL) via CMake variables. Now the discovered MATH_LIBRARY is used directly. --- ggml/src/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 48fbe208d90..52754e1b9d6 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -473,7 +473,7 @@ target_link_libraries(ggml-base PRIVATE Threads::Threads) find_library(MATH_LIBRARY m) if (MATH_LIBRARY) if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT}) - target_link_libraries(ggml-base PRIVATE m) + target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY}) endif() endif() From 682ee993057efd295a81140bae4067d6346e1f92 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 23 Apr 2026 08:22:49 +0300 Subject: [PATCH 183/249] metal : fix event synchronization (llama/22260) --- ggml/src/ggml-metal/ggml-metal-device.m | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 27cb1683518..f17f7e2e0ce 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -931,13 +931,13 @@ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { } struct ggml_metal_event { - void * obj; // id + void * obj; // id atomic_int value; }; void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { - id event = (id)ev->obj; + id event = (id)ev->obj; id cmd_buf = (id) cmd_buf_raw; @@ -945,7 +945,7 @@ void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t } void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { - id event = (id)ev->obj; + id event = (id)ev->obj; id cmd_buf = (id) cmd_buf_raw; @@ -953,7 +953,7 @@ void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cm } ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) { - id event = [dev->mtl_device newEvent]; + id event = [dev->mtl_device newSharedEvent]; ggml_metal_event_t ev = calloc(1, sizeof(struct ggml_metal_event)); @@ -964,7 +964,7 @@ ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) { } void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev) { - id event = ev->obj; + id event = ev->obj; [event release]; free(ev); @@ -973,14 +973,13 @@ void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev } void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev) { - @autoreleasepool { - id event = ev->obj; - - id cmd_buf = [dev->mtl_queue commandBuffer]; - [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)]; - [cmd_buf commit]; - [cmd_buf waitUntilCompleted]; + id event = ev->obj; + const bool res = [event waitUntilSignaledValue:atomic_load_explicit(&ev->value, memory_order_relaxed) timeoutMS:60000]; + if (!res) { + GGML_ABORT("%s: failed to wait for event\n", __func__); } + + GGML_UNUSED(dev); } void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) { From 71b1ab37841177903e7e97420489ede09bd231b1 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Thu, 23 Apr 2026 14:17:21 -0700 Subject: [PATCH 184/249] hexagon: add support for basic and extended Op profiling (llama/22269) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * hexagon: restore HTP_OPMASK_QUEUE * hexagon: honor OPMASK_SKIP_COMPUTE in hmx-matmul * hex-prof: restore op profiling * hex-prof: enable PMU * hexagon: simplify and improve op-queuing with full profiling support Add separate profile descriptors. * hexagon: remove opsync and rename opmask into opstage opsync is no longer needed since the profiler is fully async now. opmask name was confusing and opstage is more accurate. * hexagon: refactor opbatch queue handling * hexagon: add iface hooks for enabling profiler from the host Also move all the PMU setup stuff out of the hex-utils since it's not inteded for normal use. * hexagon: make profiler mode configurable On older devices getting PMU counters is expensive so it's now optional. * hexagon: add support for setting profiler pmu events from env * hexagon: simplify profiler output (no need to print buffs, etc) * hexagon: simplify pmu counter formating * hexagon: add a simple profile post-proc tool * hex-prof: add support for reading logs from stdin * hexagon: document GGML_HEXAGON_PROFILE * hex-prof: update default width for dims field * hex-prof: fix linter warnings and errors * Update ggml/src/ggml-hexagon/htp/htp-ops.h Co-authored-by: Sigbjørn Skjæret * Update scripts/snapdragon/ggml-hexagon-profile.py Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Trivikram Reddy Co-authored-by: Sigbjørn Skjæret --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 402 +++++++++++++++--------- ggml/src/ggml-hexagon/htp/hex-utils.h | 28 ++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 5 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 36 ++- ggml/src/ggml-hexagon/htp/htp_iface.idl | 8 +- ggml/src/ggml-hexagon/htp/main.c | 172 +++++++--- ggml/src/ggml-hexagon/htp/matmul-ops.c | 4 + 7 files changed, 442 insertions(+), 213 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index cdd9fcf5928..955903418b6 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -12,9 +12,12 @@ #include #include #include +#include +#include #include #include #include +#include #ifdef _WIN32 # include @@ -41,18 +44,26 @@ #include "htp_iface.h" #include "htp-drv.h" +using intvec = std::vector; +using uintvec = std::vector; +using u32vec = std::vector; + static size_t opt_ndev = 1; static size_t opt_nhvx = 0; // use all static int opt_arch = 0; // autodetect static int opt_etm = 0; static int opt_verbose = 0; -static int opt_profile = 0; +static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) static int opt_hostbuf = 1; // hostbuf ON by default static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +// Default PMU events, if profiling with PMU (mode=2) is enabled +// See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html +// https://docs.qualcomm.com/doc/80-N2040-61/topic/hvx-pmu-events.html +static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C }; + // Enable all stages by default -static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_COMPUTE; -static int opt_opsync = 0; // synchronous ops +static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches static std::regex* opt_opfilter = NULL; // regex of ops to not claim @@ -104,19 +115,26 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct } static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op, - uint32_t op_usec, uint32_t op_cycles, uint32_t op_pkts, uint64_t call_usec) { + uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) { if (!opt_profile) return; op_desc desc(op); - GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\n", sess_name.c_str(), - ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, - op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec); + + char pmu_str[256] = ""; + if (opt_profile > 1) { + static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters"); + sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]", + pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); + } + + GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(), + ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, op_usec, op_cycles, pmu_str); } // ** backend sessions struct ggml_hexagon_opbatch; -struct ggml_hexagon_opshm; +struct ggml_hexagon_opqueue; struct ggml_hexagon_session { std::string name; @@ -132,8 +150,8 @@ struct ggml_hexagon_session { bool valid_iface; std::atomic op_pending; - ggml_hexagon_opbatch *op_batch; - ggml_hexagon_opshm *op_shm; + ggml_hexagon_opbatch* op_batch; + ggml_hexagon_opqueue* op_queue; ggml_backend_buffer_type buffer_type = {}; ggml_backend_buffer_type repack_buffer_type = {}; @@ -1521,65 +1539,14 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf // Backend session implementation -struct ggml_hexagon_opshm { - ggml_hexagon_shared_buffer *sbuf; - - std::vector block_mask; - size_t block_size; - - uint8_t * base() const { return this->sbuf->base; } - int fd() const { return this->sbuf->fd; } - size_t n_blocks() const { return this->block_mask.size(); } - - ggml_hexagon_opshm(ggml_hexagon_session *sess, size_t max_batch, size_t max_pending) { - size_t n_bufs = HTP_OP_MAX_BUFS; - size_t n_ops = max_batch; - size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; - - block_mask.resize(max_pending, true); - - block_size = sizeof(htp_buf_desc) * n_bufs + - sizeof(htp_tensor) * n_tensors + - sizeof(htp_op_desc) * n_ops; - - sbuf = new ggml_hexagon_shared_buffer(sess, block_size * block_mask.size(), true /* pinned */); - - if (opt_verbose) { - GGML_LOG_INFO("ggml-hex: %s allocated shared buf %zu : block-size %zu max-batch %zu max-pending %zu\n", - sess->c_name(), (size_t) sbuf->size, block_size, max_batch, max_pending); - } - } - - ~ggml_hexagon_opshm() { - delete sbuf; - } - - uint8_t * allocate() { - auto it = std::find(block_mask.begin(), block_mask.end(), true); - if (it == block_mask.end()) - return nullptr; - - unsigned int i = std::distance(block_mask.begin(), it); - uint8_t* addr = sbuf->base + (i * block_size); - block_mask[i] = false; - - HEX_VERBOSE("ggml-hex: %s allocated op shm #%u %p\n", sbuf->sess->c_name(), i, (void*) addr); - return addr; - } - - void release(uint8_t * addr) { - int i = (addr - sbuf->base) / block_size; - block_mask[i] = true; - HEX_VERBOSE("ggml-hex: %s released op shm #%u %p\n", sbuf->sess->c_name(), i, (void*) addr); - } -}; - struct ggml_hexagon_opbatch { - const char* name; + ggml_hexagon_session* sess; - std::vector buffers; - std::vector tensors; - std::vector ops; + std::vector ops; // pointers to original ops + + std::vector h_bufs; // htp buffer descriptors + std::vector h_tens; // htp tensor descriptors + std::vector h_ops; // htp op descriptors std::unordered_map b_map; // buffer fd to index std::unordered_map t_map; // tensor ptr to index @@ -1606,19 +1573,21 @@ struct ggml_hexagon_opbatch { d_map.clear(); } - ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t max_batch) { - name = sess->c_name(); + ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size) { + this->sess = sess; n_bufs_max = HTP_OP_MAX_BUFS; - n_ops_max = max_batch; + n_ops_max = batch_size; n_tens_max = n_ops_max + n_ops_max * HTP_OP_MAX_INPUTS; b_vmem_max = HTP_OP_MAX_VMEM; - buffers.resize(n_bufs_max); - tensors.resize(n_tens_max); ops.resize(n_ops_max); + h_bufs.resize(n_bufs_max); + h_tens.resize(n_tens_max); + h_ops.resize(n_ops_max); + b_map.reserve(n_bufs_max); t_map.reserve(n_tens_max); d_map.reserve(n_tens_max); @@ -1640,7 +1609,7 @@ struct ggml_hexagon_opbatch { b_map.insert({sbuf->fd, bi}); - htp_buf_desc &b = buffers[bi]; + htp_buf_desc &b = h_bufs[bi]; b.base = (uint64_t) sbuf->base; b.fd = sbuf->fd; b.size = sbuf->size; @@ -1664,7 +1633,7 @@ struct ggml_hexagon_opbatch { // First lookup by tensor data auto range = d_map.equal_range(t->data); for (auto it = range.first; it != range.second; ++it) { - htp_tensor * h = &tensors[it->second]; + htp_tensor * h = &h_tens[it->second]; if (same_shape(h, t)) { return it->second; } } @@ -1682,7 +1651,7 @@ struct ggml_hexagon_opbatch { uint64_t t_offset = (uint8_t *) t->data - sbuf->base; size_t t_size = ggml_nbytes(t); - htp_tensor &h = tensors[ti]; + htp_tensor &h = h_tens[ti]; h.bi = add_buffer(sbuf); h.data = t_offset; h.size = t_size; @@ -1737,65 +1706,170 @@ struct ggml_hexagon_opbatch { // assumes that fit_op() was called first and returned true void add_op(htp_op_code opcode, const struct ggml_tensor * t) { // Add new op - htp_op_desc &o = ops[n_ops++]; + + unsigned int n = n_ops++; GGML_ASSERT(n_ops <= n_ops_max); + ops[n] = t; + + htp_op_desc &o = h_ops[n]; memcpy(&o.params, &t->op_params, sizeof(t->op_params)); o.opcode = opcode; o.flags = 0; - if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { + if (!(opt_opstage & HTP_OPSTAGE_COMPUTE)) { o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; } - ggml_hexagon_dump_op_exec(name, t, o.flags); + ggml_hexagon_dump_op_exec(sess->c_name(), t, o.flags); for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { o.src[i] = t->src[i] ? add_tensor(t->src[i]) : 0xffff; } o.dst = add_tensor(t); } +}; + +struct ggml_hexagon_opqueue { + // Shared buffer for storing batches + ggml_hexagon_shared_buffer *shm_buf; + size_t shm_blk_size; + + using opvec = std::vector; + + std::queue done; // completed batch ids + std::vector op_cache; // per batch op cache + std::vector start_usec; // per batch start time + + ggml_hexagon_opqueue(ggml_hexagon_session *sess, size_t batch_size, size_t depth) { + size_t n_bufs = HTP_OP_MAX_BUFS; + size_t n_ops = batch_size; + size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; + + shm_blk_size = sizeof(htp_buf_desc) * n_bufs + + sizeof(htp_tensor) * n_tensors + + sizeof(htp_op_desc) * n_ops + + sizeof(htp_prof_desc) * n_ops; + + shm_buf = new ggml_hexagon_shared_buffer(sess, shm_blk_size * depth, true /* pinned */); + + op_cache.resize(depth); + start_usec.resize(depth, 0); + + // init done queue + for (unsigned int i = 0; i < depth; i++) { done.push(i); } + + if (opt_verbose) { + GGML_LOG_INFO("ggml-hex: %s allocated op-queue : batch-size %zu depth %zu shm-size %zu shm-block-size %zu\n", + sess->c_name(), batch_size, depth, shm_buf->size, shm_blk_size); + } + } - size_t flush(uint8_t * mem_addr, size_t mem_size) { - static_assert(sizeof(htp_buf_desc) % 8 == 0, "sizeof(htp_buf_desc) must be multiple of 8"); - static_assert(sizeof(htp_tensor) % 8 == 0, "sizeof(htp_tensor) must be multiple of 8"); - static_assert(sizeof(htp_op_desc) % 8 == 0, "sizeof(htp_op_desc) must be multiple of 8"); + ~ggml_hexagon_opqueue() { + delete shm_buf; + } - const size_t b_size = sizeof(htp_buf_desc) * n_bufs; - const size_t t_size = sizeof(htp_tensor) * n_tens; - const size_t o_size = sizeof(htp_op_desc) * n_ops; + // push new batch + bool push(htp_opbatch_req& req, dspqueue_buffer& dbuf, ggml_hexagon_opbatch* op_batch) { + static_assert(sizeof(htp_opbatch_req) % 8 == 0, "sizeof(htp_opbatch_req) must be multiple of 8"); + static_assert(sizeof(htp_opbatch_rsp) % 8 == 0, "sizeof(htp_opbatch_rsp) must be multiple of 8"); + static_assert(sizeof(htp_buf_desc) % 8 == 0, "sizeof(htp_buf_desc) must be multiple of 8"); + static_assert(sizeof(htp_tensor) % 8 == 0, "sizeof(htp_tensor) must be multiple of 8"); + static_assert(sizeof(htp_op_desc) % 8 == 0, "sizeof(htp_op_desc) must be multiple of 8"); + static_assert(sizeof(htp_prof_desc) % 8 == 0, "sizeof(htp_prof_desc) must be multiple of 8"); - const size_t m_size = b_size + t_size + o_size; - GGML_ASSERT(m_size <= mem_size); + if (done.empty()) { return false; } - uint8_t * b_ptr = (uint8_t *) mem_addr; - uint8_t * t_ptr = (uint8_t *) b_ptr + b_size; - uint8_t * o_ptr = (uint8_t *) t_ptr + t_size; + req.id = done.front(); done.pop(); // batch id + req.n_bufs = op_batch->n_bufs; + req.n_tensors = op_batch->n_tens; + req.n_ops = op_batch->n_ops; - memcpy(b_ptr, (void *) buffers.data(), b_size); - memcpy(t_ptr, (void *) tensors.data(), t_size); - memcpy(o_ptr, (void *) ops.data(), o_size); + op_cache[req.id] = op_batch->ops; + start_usec[req.id] = ggml_time_us(); - HEX_VERBOSE("ggml-hex: %s flush-opbatch : n-bufs %u n-tensors %u n-ops %u vmem %zu : b-size %zu t-size %zu o-size %zu\n", - name, n_bufs, n_tens, n_ops, b_vmem, b_size, t_size, o_size); + const size_t b_size = sizeof(htp_buf_desc) * req.n_bufs; + const size_t t_size = sizeof(htp_tensor) * req.n_tensors; + const size_t o_size = sizeof(htp_op_desc) * req.n_ops; + const size_t p_size = sizeof(htp_prof_desc) * req.n_ops; + + dbuf.ptr = shm_buf->base + (req.id * shm_blk_size); + dbuf.fd = shm_buf->fd; + dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) shm_buf->base; + dbuf.size = b_size + t_size + o_size + p_size; + + GGML_ASSERT(dbuf.size <= shm_blk_size); + + uint8_t * m_ptr = (uint8_t*) dbuf.ptr; + uint8_t * b_ptr = m_ptr; m_ptr += b_size; + uint8_t * t_ptr = m_ptr; m_ptr += t_size; + uint8_t * o_ptr = m_ptr; + + memcpy(b_ptr, (void *) op_batch->h_bufs.data(), b_size); + memcpy(t_ptr, (void *) op_batch->h_tens.data(), t_size); + memcpy(o_ptr, (void *) op_batch->h_ops.data(), o_size); + + HEX_VERBOSE("ggml-hex: %s op-queue push batch #%u : n-bufs %u n-tensors %u n-ops %u vmem %zu : b-size %zu t-size %zu o-size %zu m-size %zu\n", + shm_buf->sess->c_name(), req.id, req.n_bufs, req.n_tensors, req.n_ops, op_batch->b_vmem, + b_size, t_size, o_size, (size_t) dbuf.size); + + op_batch->reset(); if (opt_verbose > 1) { htp_buf_desc *b = (htp_buf_desc*) b_ptr; - for (unsigned int i=0; i < n_bufs; i++) { - GGML_LOG_DEBUG("ggml-hex: %s htp-buf #%u : fd %d base %p size %zu\n", name, i, + for (unsigned int i=0; i < req.n_bufs; i++) { + GGML_LOG_DEBUG("ggml-hex: %s htp-buf #%u : fd %d base %p size %zu\n", shm_buf->sess->c_name(), i, b[i].fd, (void *) b[i].base, (size_t) b[i].size); } htp_tensor *t = (htp_tensor*) t_ptr; - for (unsigned int i=0; i < n_tens; i++) { + for (unsigned int i=0; i < req.n_tensors; i++) { GGML_LOG_DEBUG("ggml-hex: %s htp-tensor #%u : bi %u offset %u size %u : %zu:%zu:%zu:%zu\n", - name, i, t[i].bi, t[i].data, t[i].size, + shm_buf->sess->c_name(), i, t[i].bi, t[i].data, t[i].size, (size_t) t[i].ne[0], (size_t) t[i].ne[1], (size_t) t[i].ne[2], (size_t) t[i].ne[3]); } } - reset(); + return true; + } + + void pop(htp_opbatch_rsp rsp, dspqueue_buffer dbuf) { + GGML_ASSERT(rsp.id < op_cache.size()); + + done.push(rsp.id); + + const size_t b_size = sizeof(htp_buf_desc) * rsp.n_bufs; + const size_t t_size = sizeof(htp_tensor) * rsp.n_tensors; + const size_t o_size = sizeof(htp_op_desc) * rsp.n_ops; + const size_t p_size = sizeof(htp_prof_desc) * rsp.n_ops; - return m_size; + const size_t m_size = b_size + t_size + o_size + p_size; + GGML_ASSERT(m_size <= shm_blk_size); + + HEX_VERBOSE("ggml-hex: %s op-queue pop batch #%u : n-bufs %u n-tensors %u n-ops %u : m-size %zu b-size %zu t-size %zu o-size %zu\n", + shm_buf->sess->c_name(), rsp.id, rsp.n_bufs, rsp.n_tensors, rsp.n_ops, + (size_t) dbuf.size, b_size, t_size, o_size); + + uint8_t * m_ptr = (uint8_t*) dbuf.ptr; + uint8_t * p_ptr = m_ptr + (b_size + t_size + o_size); + + if (opt_profile && rsp.n_ops > 0) { + auto & ops = op_cache[rsp.id]; + + uint64_t batch_usec = ggml_time_us() - start_usec[rsp.id]; + uint32_t htp_usec = 0; + + GGML_ASSERT(rsp.n_ops <= ops.size()); + + const htp_prof_desc * pd = (const htp_prof_desc *) p_ptr; + for (uint32_t i = 0; i < rsp.n_ops; i++) { + htp_usec += pd[i].usecs; + ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i].usecs, pd[i].cycles, pd[i].pmu); + } + + GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u\n", + shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec); + } } }; @@ -1824,17 +1898,12 @@ void ggml_hexagon_session::flush_pending(bool all) { GGML_ABORT("ggml-hex: %s dspcall : bad response : size %u dspbufs %u\n", this->c_name(), rsp_size, n_dbufs); } - op_shm->release((uint8_t*) dbuf.ptr); - if (rsp.status != HTP_STATUS_OK) { GGML_LOG_ERROR("ggml-hex: %s dspcall : dsp-rsp: %s\n", this->c_name(), status_to_str(rsp.status)); // TODO: handle errors } - // FIXME: profile will be per opreq - // this->prof_usecs = rsp.prof_usecs; - // this->prof_cycles = rsp.prof_cycles; - // this->prof_pkts = rsp.prof_pkts; + op_queue->pop(rsp, dbuf); this->op_pending--; // atomic dec @@ -1845,28 +1914,17 @@ void ggml_hexagon_session::flush_pending(bool all) { void ggml_hexagon_session::flush_batch() { if (op_batch->empty()) { return; } - htp_opbatch_req req; - req.n_bufs = op_batch->n_bufs; - req.n_tensors = op_batch->n_tens; - req.n_ops = op_batch->n_ops; + htp_opbatch_req req {}; + dspqueue_buffer dbuf{}; - dspqueue_buffer dbuf; - dbuf.fd = op_shm->fd(); - dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; - dbuf.ptr = op_shm->allocate(); - if (!dbuf.ptr) { + if (!op_queue->push(req, dbuf, op_batch)) { flush_pending(false); - dbuf.ptr = op_shm->allocate(); + op_queue->push(req, dbuf, op_batch); } - dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) op_shm->base(); - dbuf.size = op_batch->flush((uint8_t*) dbuf.ptr, op_shm->block_size); - // Bump pending flag (cleared in the session::flush once we get the response) this->op_pending++; // atomic inc - HEX_VERBOSE("ggml-hex: %s: queue-opbatch : %p size %u\n", this->c_name(), dbuf.ptr, dbuf.size); - int err = dspqueue_write(this->queue, 0, 1, &dbuf, sizeof(req), (const uint8_t*) &req, DSPQUEUE_TIMEOUT); if (err != 0) { GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->c_name(), (unsigned) err); @@ -2016,25 +2074,33 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } if (opt_etm) { - err = htp_iface_enable_etm(this->handle); + err = htp_iface_etm(this->handle, 1); if (err != 0) { GGML_LOG_ERROR("ggml-hex: failed to enable ETM tracing: 0x%08x\n", (unsigned) err); } } - // Start the DSP-side service. We need to pass the queue ID to the - // DSP in a FastRPC call; the DSP side will import the queue and start - // listening for packets in a callback. + if (opt_profile) { + htp_iface_pmu_conf pmu_conf{}; + std::copy(opt_pmu_evt.begin(), opt_pmu_evt.end(), pmu_conf.events); + + err = htp_iface_profiler(this->handle, opt_profile, &pmu_conf); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to enable profiling: 0x%08x\n", (unsigned) err); + } + } + + // Allocate buffers and state for op batching + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch); + this->op_queue = new ggml_hexagon_opqueue(this, opt_opbatch, opt_opqueue); + + // Start processing op batch requests err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx); if (err != 0) { GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); } this->valid_iface = true; - - // Allocate buffers and state for op batching - this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch); - this->op_shm = new ggml_hexagon_opshm(this, opt_opbatch, opt_opqueue); } void ggml_hexagon_session::release() noexcept(true) { @@ -2043,7 +2109,7 @@ void ggml_hexagon_session::release() noexcept(true) { int err; delete this->op_batch; - delete this->op_shm; + delete this->op_queue; // Stop the DSP-side service and close the queue if (this->valid_iface) { @@ -2054,12 +2120,20 @@ void ggml_hexagon_session::release() noexcept(true) { } if (opt_etm) { - err = htp_iface_disable_etm(this->handle); + err = htp_iface_etm(this->handle, 0); if (err != 0) { GGML_LOG_ERROR("ggml-hex: warn : failed to disable ETM tracing: 0x%08x\n", (unsigned) err); } } + if (opt_profile) { + htp_iface_pmu_conf pmu_conf{}; + err = htp_iface_profiler(this->handle, 0, &pmu_conf); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: warn : failed to disable profiling: 0x%08x\n", (unsigned) err); + } + } + if (this->valid_queue) { err = dspqueue_close(queue); if (err != 0) { @@ -2077,7 +2151,7 @@ ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) n repack_buffer_type.device = dev; op_batch = nullptr; - op_shm = nullptr; + op_queue = nullptr; try { allocate(dev_id); @@ -2698,7 +2772,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg for (int i = 0; i < graph->n_nodes; ++i) { ggml_tensor * n = graph->nodes[i]; - if (op_is_compute(n)) { + if (op_is_compute(n) && (opt_opstage & HTP_OPSTAGE_QUEUE)) { sess->enqueue_op(op_remap_to_htp(n), n); } } @@ -3338,6 +3412,26 @@ static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, cons return NULL; } +template std::vector str_to_vec(const char* str) { + std::stringstream ss(str); + std::vector v; + std::string t; + + while (std::getline(ss, t, ',')) { + v.push_back(std::stoul(t, nullptr, 0)); + } + + return v; +} + +template std::string vec_to_str(std::vector v) { + std::stringstream ss; + ss << std::setbase(BASE) << std::showbase; + for (auto i : v) { ss << i << ','; } + auto str = ss.str(); str.pop_back(); // drop last comma + return str; +} + static void ggml_hexagon_init(ggml_backend_reg * reg) { // Basic sanity checks to make sure definitions match static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, @@ -3351,8 +3445,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); - const char * str_opmask = getenv("GGML_HEXAGON_OPMASK"); - const char * str_opsync = getenv("GGML_HEXAGON_OPSYNC"); + const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); const char * str_opfilter= getenv("GGML_HEXAGON_OPFILTER"); @@ -3365,19 +3458,30 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { auto RE_ICASE = std::regex_constants::icase; - opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; - opt_verbose = str_verbose ? atoi(str_verbose) : 0; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; - opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; - opt_opsync = str_opsync ? atoi(str_opsync) : opt_opsync; - opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; - opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; - opt_profile = str_profile ? atoi(str_profile) : 0; - opt_etm = str_etm ? atoi(str_etm) : 0; - opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; - opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; - opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; + opt_verbose = str_verbose ? atoi(str_verbose) : 0; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; + opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; + opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; + opt_etm = str_etm ? atoi(str_etm) : 0; + opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; + opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; + opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + + if (str_profile) { + opt_pmu_evt = [&]() -> std::vector { + auto v = str_to_vec(str_profile); + switch (v.size()) { + case 1: opt_profile = v[0]; return opt_pmu_evt; // mode with default pmu events + case 8: opt_profile = 2; return v; // mode with custom pmu events + default: opt_profile = 0; return {}; // garbage input + }}(); + if (opt_profile == 1) opt_pmu_evt = {}; + GGML_LOG_INFO("ggml-hex: Profiling mode %u : pmu-evt [ %s ]\n", opt_profile, + vec_to_str(opt_pmu_evt).c_str()); + } if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { opt_ndev = GGML_HEXAGON_MAX_SESSIONS; diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index f6713c5cf8f..329249e11da 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "hexagon_types.h" #include "hexagon_protos.h" @@ -100,4 +101,31 @@ static inline void hex_pause() { asm volatile(" pause(#255)\n"); } +#ifndef HEX_NUM_PMU_COUNTERS +#define HEX_NUM_PMU_COUNTERS 8 +#endif + +static inline void hex_get_pmu(uint32_t counters[]) { +#if __HVX_ARCH__ >= 79 + asm volatile("%0 = upmucnt0" : "=r"(counters[0])); + asm volatile("%0 = upmucnt1" : "=r"(counters[1])); + asm volatile("%0 = upmucnt2" : "=r"(counters[2])); + asm volatile("%0 = upmucnt3" : "=r"(counters[3])); + asm volatile("%0 = upmucnt4" : "=r"(counters[4])); + asm volatile("%0 = upmucnt5" : "=r"(counters[5])); + asm volatile("%0 = upmucnt6" : "=r"(counters[6])); + asm volatile("%0 = upmucnt7" : "=r"(counters[7])); +#else + counters[0] = qurt_pmu_get(QURT_PMUCNT0); + counters[1] = qurt_pmu_get(QURT_PMUCNT1); + counters[2] = qurt_pmu_get(QURT_PMUCNT2); + counters[3] = qurt_pmu_get(QURT_PMUCNT3); + counters[4] = qurt_pmu_get(QURT_PMUCNT4); + counters[5] = qurt_pmu_get(QURT_PMUCNT5); + counters[6] = qurt_pmu_get(QURT_PMUCNT6); + counters[7] = qurt_pmu_get(QURT_PMUCNT7); + // qurt_pmu_get_pmucnt(counters); +#endif +} + #endif /* HEX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 78455e6b071..f8c89211aed 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -10,6 +10,7 @@ #include #include #include +#include #define HTP_MAX_NTHREADS 10 #define HTP_MAX_MMAPS 16 @@ -66,7 +67,9 @@ struct htp_context { int thread_id; int thread_prio; - int hmx_enabled; + bool hmx_enabled; + bool etm; + uint32_t profiler; uint8_t * vtcm_base; size_t vtcm_size; diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 62d6ec02241..56d7b398d10 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -42,9 +42,9 @@ enum htp_data_type { // Mask to enable various stages of the Ops. // Used for debugging and profiling. -enum htp_op_mask { - HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP) - HTP_OPMASK_COMPUTE = (1 << 1), // Enable Compute +enum htp_op_stage { + HTP_OPSTAGE_QUEUE = (1 << 0), // Enable Queueing (ie calls into NPU) + HTP_OPSTAGE_COMPUTE = (1 << 1), // Enable Compute }; // Do not reorder first 4 (used as an index) @@ -137,27 +137,45 @@ struct htp_op_desc { int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices uint16_t dst; // Output tensor index +}; + +enum htp_profiler_mode { + HTP_PROF_DISABLED = 0, + HTP_PROF_BASIC = 1, + HTP_PROF_PMU = 2, +}; + +#define HTP_PROF_PMU_NCNT 8 - // the rest is filled in-place by the NPU - uint32_t prof_usecs; // Number of usec per request - uint32_t prof_cycles; // Number of cycles per request - uint32_t prof_pkts; // Number of instruction packets per request - uint32_t unused; +// Profile descriptor +struct htp_prof_desc { + uint32_t opcode; // GGML/HTP Op + uint32_t usecs; // Number of usec + uint32_t cycles; // Number of cycles + uint32_t pad; // Unused + uint32_t pmu[HTP_PROF_PMU_NCNT]; // PMU counters }; struct htp_opbatch_req { + uint32_t id; // Batch id uint32_t n_bufs; // Number of buffers uint32_t n_tensors; // Number of tensors uint32_t n_ops; // Number of ops uint32_t flags; // unused + uint32_t pad; // unused // struct htp_buf_desc bufs[]; -- dspqueue buf 0 // struct htp_tensor tensors[]; -- dspqueue buf 0 // struct htp_op_desc ops[]; -- dspqueue buf 0 }; struct htp_opbatch_rsp { + uint32_t id; // Batch id uint32_t status; // HTP_STATUS_... - // struct htp_op_req ops[]; -- dspqueue buf 0 + uint32_t n_bufs; // Number of buffers + uint32_t n_tensors; // Number of tensors + uint32_t n_ops; // Number of op profile descriptors + uint32_t pad; // unused + // struct htp_prof_desc profs[]; -- dspqueue buf 0 }; #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index 3eb5d5a6912..dbcafd1d856 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -6,13 +6,17 @@ #include "AEEStdDef.idl" #include "remote.idl" +struct htp_iface_pmu_conf { + uint32 events[8]; +}; + interface htp_iface : remote_handle64 { AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx); AEEResult stop(); AEEResult mmap(in uint32 fd, in uint32 size, in uint32 pinned); AEEResult munmap(in uint32 fd); - AEEResult enable_etm(); - AEEResult disable_etm(); + AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu); + AEEResult etm(in uint32 enable); }; #endif /* HTP_IDL */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 9185c9ffe15..088434a63e9 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -27,6 +27,7 @@ #include "htp-ctx.h" #include "htp-ops.h" #include "htp-ops.h" +#include "htp_iface.h" #include "worker-pool.h" AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { @@ -103,6 +104,54 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { return AEE_SUCCESS; } +AEEResult htp_iface_etm(remote_handle64 handle, uint32_t enable) { + int err = enable ? HAP_user_etm_enable() : HAP_user_etm_disable(); + if (err) { + if (err == AEE_EVERSIONNOTSUPPORT) { + FARF(ERROR, "API HAP_user_etm_enable/disable is not supported\n"); + } else { + FARF(ERROR, "Error executing HAP_user_etm_enable/disable with error code : 0x%x\n", err); + } + } + return err; +} + +AEEResult htp_iface_profiler(remote_handle64 handle, uint32_t mode, const htp_iface_pmu_conf* pmu_conf) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } + + if (mode == HTP_PROF_PMU) { + const uint32_t* events = pmu_conf->events; + + // Pack 4 event IDs (low 8 bits) into each 32-bit config register + uint32_t evtcfg = 0, evtcfg1 = 0, cfg = 0, i = 0; + for (; i < HEX_NUM_PMU_COUNTERS/2; i++) { + evtcfg |= ((events[i + 0] & 0xFF) << (i * 8)); + evtcfg1 |= ((events[i + 4] & 0xFF) << (i * 8)); + } + + // For events >255 pack high 2 bits of all 8 event IDs into cfg register + // 2 bits per counter: bits [1:0] for counter 0, [3:2] for counter 1, etc. + for (i = 0; i < HEX_NUM_PMU_COUNTERS; i++) { + cfg |= (((events[i] >> 8) & 3) << (i * 2)); + } + + FARF(ALWAYS, "Configuring PMU registers: evtcfg = 0x%x, evtcfg1 = 0x%x, pmucfg = 0x%x", evtcfg, evtcfg1, cfg); + + // Configure PMU registers + qurt_pmu_set(QURT_PMUCFG, cfg); + qurt_pmu_set(QURT_PMUEVTCFG, evtcfg); + qurt_pmu_set(QURT_PMUEVTCFG1, evtcfg1); + qurt_pmu_enable(1); + } + + ctx->profiler = mode; + + return AEE_SUCCESS; +} + AEEResult htp_iface_close(remote_handle64 handle) { struct htp_context * ctx = (struct htp_context *) handle; @@ -129,35 +178,19 @@ AEEResult htp_iface_close(remote_handle64 handle) { } } - free(ctx); - return AEE_SUCCESS; -} - -AEEResult htp_iface_enable_etm(remote_handle64 handle) { - int err = HAP_user_etm_enable(); - if (err) { - if (err == AEE_EVERSIONNOTSUPPORT) { - FARF(ERROR, "API HAP_user_etm_enable is not supported\n"); - } else { - FARF(ERROR, "Error executing HAP_user_etm_enable with error code : 0x%x\n", err); - } + if (ctx->profiler) { + qurt_pmu_enable(1); } - return err; -} -AEEResult htp_iface_disable_etm(remote_handle64 handle) { - int err = HAP_user_etm_disable(); - if (err) { - if (err == AEE_EVERSIONNOTSUPPORT) { - FARF(ERROR, "API HAP_user_etm_disable is not supported\n"); - } else { - FARF(ERROR, "Error executing HAP_user_etm_disable with error code : 0x%x\n", err); - } + if (ctx->etm) { + HAP_user_etm_disable(); } - return err; + + free(ctx); + return AEE_SUCCESS; } -AEEResult htp_iface_mmap(remote_handle64 handle, int fd, uint32_t size, uint32_t pinned) { +AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 pinned) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { return AEE_EBADPARM; @@ -204,7 +237,7 @@ AEEResult htp_iface_mmap(remote_handle64 handle, int fd, uint32_t size, uint32_t return AEE_ENOMEMORY; } -AEEResult htp_iface_munmap(remote_handle64 handle, int fd) { +AEEResult htp_iface_munmap(remote_handle64 handle, uint32 fd) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { return AEE_EBADPARM; @@ -434,19 +467,39 @@ static void htp_error_callback(dspqueue_t queue, int error, void * context) { struct profile_data { uint64_t usecs; uint64_t cycles; - uint64_t pkts; + uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS]; }; -static inline void profile_start(struct profile_data * d) { - d->usecs = HAP_perf_get_qtimer_count(); - d->cycles = hex_get_cycles(); - d->pkts = hex_get_pktcnt(); +static inline void profile_start(uint32_t mode, struct profile_data * d) { + switch (mode) { + case HTP_PROF_PMU: + hex_get_pmu(d->pmu_counters); + // fallthrough + case HTP_PROF_BASIC: + d->usecs = HAP_perf_get_qtimer_count(); + d->cycles = hex_get_cycles(); + break; + default: + break; + } } -static inline void profile_stop(struct profile_data * d) { - d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); - d->cycles = hex_get_cycles() - d->cycles; - d->pkts = hex_get_pktcnt() - d->pkts; +static inline void profile_stop(uint32_t mode, struct profile_data * d) { + uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS]; + switch (mode) { + case HTP_PROF_PMU: + hex_get_pmu(pmu_counters); + for (int i = 0; i < HEX_NUM_PMU_COUNTERS; i++) { + d->pmu_counters[i] = pmu_counters[i] - d->pmu_counters[i]; + } + // fallthrough + case HTP_PROF_BASIC: + d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); + d->cycles = hex_get_cycles() - d->cycles; + break; + default: + break; + } } static int execute_op(struct htp_ops_context * octx) { @@ -726,29 +779,32 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { continue; } + // Reset poll count for valid requests + poll_count = DSPQUEUE_POLL_COUNT; + const uint32_t n_bufs = req.n_bufs; const uint32_t n_tens = req.n_tensors; const uint32_t n_ops = req.n_ops; - const uint32_t b_size = sizeof(struct htp_buf_desc) * n_bufs; - const uint32_t t_size = sizeof(struct htp_tensor) * n_tens; - const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops; + const uint32_t b_size = sizeof(struct htp_buf_desc) * n_bufs; + const uint32_t t_size = sizeof(struct htp_tensor) * n_tens; + const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops; + const uint32_t p_size = sizeof(struct htp_prof_desc) * n_ops; - if (dbuf.size < b_size + t_size + o_size) { + if (dbuf.size < b_size + t_size + o_size + p_size) { FARF(ERROR, "invalid opbatch memory block size %u", dbuf.size); break; } - // Reset poll count for valid requests - poll_count = DSPQUEUE_POLL_COUNT; + FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", req.id, + n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size); + // Setup descriptor pointers uint8_t * m_ptr = dbuf.ptr; - struct htp_buf_desc* bufs = (struct htp_buf_desc*) m_ptr; m_ptr += b_size; - struct htp_tensor* tens = (struct htp_tensor*) m_ptr; m_ptr += t_size; - struct htp_op_desc* ops = (struct htp_op_desc*) m_ptr; - - FARF(HIGH, "processing opbatch: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", - n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size); + struct htp_buf_desc* bufs = (struct htp_buf_desc*) m_ptr; m_ptr += b_size; + struct htp_tensor* tens = (struct htp_tensor*) m_ptr; m_ptr += t_size; + struct htp_op_desc* ops = (struct htp_op_desc*) m_ptr; m_ptr += o_size; + struct htp_prof_desc* pds = (struct htp_prof_desc*) m_ptr; prep_op_bufs(ctx, bufs, n_bufs); prep_tensors(ctx, bufs, tens, n_tens); @@ -760,22 +816,34 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { for (uint32_t i=0; i < n_ops; i++) { struct profile_data prof; - profile_start(&prof); + + profile_start(ctx->profiler, &prof); proc_op_req(octx, tens, i, &ops[i]); - profile_stop(&prof); - ops[i].prof_usecs = prof.usecs; - ops[i].prof_cycles = prof.cycles; - ops[i].prof_pkts = prof.pkts; + profile_stop(ctx->profiler, &prof); + + if (ctx->profiler) { + pds[i].opcode = ops[i].opcode; + pds[i].usecs = prof.usecs; + pds[i].cycles = prof.cycles; + for (int j = 0; j < HEX_NUM_PMU_COUNTERS; j++) { + pds[i].pmu[j] = prof.pmu_counters[j]; + } + } } // dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); struct htp_opbatch_rsp rsp; - rsp.status = HTP_STATUS_OK; // FIXME + rsp.id = req.id; + rsp.status = HTP_STATUS_OK; + rsp.n_bufs = n_bufs; + rsp.n_tensors = n_tens; + rsp.n_ops = n_ops; dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + err = dspqueue_write(queue, 0, 1, &dbuf, sizeof(rsp), (const uint8_t *) &rsp, DSPQUEUE_TIMEOUT_NONE); if (err != 0) { FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index bac06693d81..a0c265132c8 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -3017,6 +3017,10 @@ int op_matmul(struct htp_ops_context * octx) { const int act_stride = (int)(src1->nb[1] / sizeof(float)); const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16)); + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + if (src0->type == HTP_TYPE_F16) { if (is_batched) { hmx_matmul_w16a32_batched_params_t batch_params = { From 641998f558afb6dae907e86ef0a44995b8a00592 Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Thu, 23 Apr 2026 19:32:59 -0400 Subject: [PATCH 185/249] fix(shader): handle the buffer aliasing for rms fuse (llama/22266) --- ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 14 ++++++++++---- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 6 ++++-- .../ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl | 17 ++++++++++++++++- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index efc5b8c97a7..449eae808e4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -197,11 +197,12 @@ struct ggml_webgpu_row_norm_pipeline_key_hash { /** RMS_NORM + MUL **/ struct ggml_webgpu_rms_norm_mul_pipeline_key { - bool inplace; - bool src_overlap; + bool inplace; // rn_src == dst + bool overlap; // mul_src == dst + bool src_overlap; // rn_src == mul_src bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const { - return inplace == other.inplace && src_overlap == other.src_overlap; + return inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap; } }; @@ -209,6 +210,7 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash { size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); ggml_webgpu_hash_combine(seed, key.src_overlap); return seed; } @@ -556,7 +558,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ const size_t q_tile = context.sg_mat_m; const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; + size_t bytes_per_kv = 0; if (!key.kv_direct) { bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); } @@ -1878,6 +1880,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_rms_norm_mul_pipeline_key key = {}; key.inplace = context.inplace; + key.overlap = context.overlap; key.src_overlap = context.src_overlap; auto it = rms_norm_mul_pipelines.find(key); @@ -1892,6 +1895,9 @@ class ggml_webgpu_shader_lib { if (key.inplace) { defines.push_back("INPLACE"); variant += "_inplace"; + } else if (key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; } else if (key.src_overlap) { defines.push_back("SRC_OVERLAP"); variant += "_src_overlap"; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index bcca2bd4627..acc486cfdda 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2071,8 +2071,9 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); } - bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || + bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); + bool inplace = ggml_webgpu_tensor_equal(rn_src, dst); bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); uint32_t offset_merged_rn_src = 0; @@ -2116,7 +2117,7 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context std::vector entries; - if (inplace) { + if (inplace || overlap) { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); } else if (src_overlap) { @@ -2136,6 +2137,7 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.inplace = inplace; + shader_lib_ctx.overlap = overlap; shader_lib_ctx.src_overlap = src_overlap; webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl index 71f063b51aa..74aaa2753ae 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl @@ -1,4 +1,4 @@ -#ifdef INPLACE +#ifdef OVERLAP @group(0) @binding(0) var rn_src: array; @@ -13,6 +13,21 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) mul_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; } +#elif INPLACE + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + rn_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + #elif SRC_OVERLAP @group(0) @binding(0) From 23921d5a695262bdf9bdb34300f179aa97ae7a1e Mon Sep 17 00:00:00 2001 From: Mengsheng Wu Date: Fri, 24 Apr 2026 09:39:13 +0800 Subject: [PATCH 186/249] hexagon: add SOLVE_TRI op (llama/21974) * hexagon: add SOLVE_TRI op * ggml: fix TODO description for solve_tri * hexagon: rm unused variable/function warnings * hexagon: chunk vs batch processingfor better thread utilization * hexagon: vectorize partial f32 loads * hexagon: move HVX f32 add/sub/mul wrappers to hvx-base.h --------- Co-authored-by: Todor Boinovski --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 39 +++- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 2 +- ggml/src/ggml-hexagon/htp/hvx-base.h | 24 ++ ggml/src/ggml-hexagon/htp/main.c | 3 + ggml/src/ggml-hexagon/htp/solve-tri-ops.c | 267 ++++++++++++++++++++++ 7 files changed, 335 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/solve-tri-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 955903418b6..0d9b5e289bb 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2693,6 +2693,39 @@ static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess return true; } +static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // A + const struct ggml_tensor * src1 = op->src[1]; // B + const struct ggml_tensor * dst = op; // X + + if (!src0 || !src1) { + return false; + } + + if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + if (src0->ne[0] != src0->ne[1]) { + return false; + } + + if (src0->ne[1] != src1->ne[1]) { + return false; + } + + if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) { + return false; + } + + if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || dst->ne[3] != src1->ne[3]) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->c_name(); @@ -2731,7 +2764,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; - + case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; @@ -3277,6 +3310,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_diag(sess, op); break; + case GGML_OP_SOLVE_TRI: + supp = ggml_hexagon_supported_solve_tri(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index b1ae60a9c43..8bd528478ba 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -36,6 +36,7 @@ add_library(${HTP_LIB} SHARED cumsum-ops.c fill-ops.c diag-ops.c + solve-tri-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index f8c89211aed..d704fedee9d 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -103,5 +103,6 @@ int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); +int op_solve_tri(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 56d7b398d10..4397245c5b8 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -82,7 +82,7 @@ enum htp_op_code { HTP_OP_CUMSUM, HTP_OP_FILL, HTP_OP_DIAG, - + HTP_OP_SOLVE_TRI, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index ed6026e762a..d0926dedd28 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -256,6 +256,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)); } +static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)); +} + +static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)); +} + +static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)); +} + #else static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) @@ -273,6 +285,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) return Q6_Vhf_vmpy_VhfVhf(a, b); } +static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vadd_VsfVsf(a, b); +} + +static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vsub_VsfVsf(a, b); +} + +static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vmpy_VsfVsf(a, b); +} + #endif // __HVX_ARCH__ < 79 #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 088434a63e9..db277a25e5a 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -573,6 +573,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_DIAG: return op_diag(octx); + case HTP_OP_SOLVE_TRI: + return op_solve_tri(octx); + case HTP_OP_INVALID: break; diff --git a/ggml/src/ggml-hexagon/htp/solve-tri-ops.c b/ggml/src/ggml-hexagon/htp/solve-tri-ops.c new file mode 100644 index 00000000000..ae8e1a50495 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/solve-tri-ops.c @@ -0,0 +1,267 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hvx-utils.h" + +struct htp_solve_tri_context { + struct htp_ops_context * octx; + uint32_t jobs_per_thread; + uint32_t total_jobs; + uint32_t k_chunks; + uint32_t col_block; +}; + +static inline void solve_tri_row_scalar(const float * A_row, + const float * B_row, + float * X, + uint32_t row, + uint32_t k, + uint32_t col0, + uint32_t coln, + float inv_diag) { + for (uint32_t col = col0; col < col0 + coln; ++col) { + float sum = 0.0f; + for (uint32_t t = 0; t < row; ++t) { + sum += A_row[t] * X[t * k + col]; + } + X[row * k + col] = (B_row[col] - sum) * inv_diag; + } +} + +static inline HVX_Vector hvx_load_partial_f32(const float * src, uint32_t n) { + HVX_Vector v = *((const HVX_UVector *) src); + HVX_VectorPred mask = Q6_Q_vsetq2_R(n * sizeof(float)); + return Q6_V_vmux_QVV(mask, v, Q6_V_vzero()); +} + +static inline void solve_tri_row_hvx(const float * A_row, + const float * B_row, + float * X, + uint32_t row, + uint32_t k, + uint32_t col0, + uint32_t coln, + float inv_diag) { + const bool full = (coln == VLEN_FP32); + + HVX_Vector sum_v = Q6_V_vzero(); + for (uint32_t t = 0; t < row; ++t) { + const float a = A_row[t]; + const float * x_row_col = X + t * k + col0; + + HVX_Vector x_v = full ? *((const HVX_UVector *) x_row_col) : hvx_load_partial_f32(x_row_col, coln); + HVX_Vector a_v = hvx_vec_splat_f32(a); + sum_v = hvx_vec_add_f32_f32(sum_v, hvx_vec_mul_f32_f32(x_v, a_v)); + } + + const float * b_row_col = B_row + col0; + float * x_out_col = X + row * k + col0; + + HVX_Vector b_v = full ? *((const HVX_UVector *) b_row_col) : hvx_load_partial_f32(b_row_col, coln); + HVX_Vector inv_diag_v = hvx_vec_splat_f32(inv_diag); + + HVX_Vector out_v = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(b_v, sum_v), inv_diag_v); + hvx_vec_store_u((void *) x_out_col, coln * sizeof(float), out_v); +} + +// Batch-level thread: each job is one full batch. +static void solve_tri_batch_thread_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data; + struct htp_ops_context * octx = sctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + const uint32_t n = src0->ne[0]; + const uint32_t k = src1->ne[0]; + + const uint32_t ne02 = src0->ne[2]; + + const uint32_t col_block = VLEN_FP32; + const uint32_t k_full = (k / col_block) * col_block; + + const uint32_t start_batch = sctx->jobs_per_thread * ith; + const uint32_t end_batch = MIN(start_batch + sctx->jobs_per_thread, sctx->total_jobs); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t batch = start_batch; batch < end_batch; ++batch) { + const uint32_t i03 = batch / ne02; + const uint32_t i02 = batch - i03 * ne02; + + const float * A_batch = + (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]); + const float * B_batch = + (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]); + float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]); + + for (uint32_t row = 0; row < n; ++row) { + const float diag = A_batch[row * n + row]; + const float inv_diag = 1.0f / diag; + const float * A_row = A_batch + row * n; + const float * B_row = B_batch + row * k; + + uint32_t col0 = 0; + for (; col0 < k_full; col0 += col_block) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, col_block, inv_diag); + } + + if (col0 < k) { + const uint32_t coln = k - col0; + if (coln >= 8) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } else { + solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "solve-tri-batch %d/%d: A=(%ux%u) B=(%ux%u) batch %u:%u usec %u\n", + ith, nth, n, n, k, n, start_batch, end_batch, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// Chunk-level thread: each job is one (batch, col_chunk) pair. +static void solve_tri_chunk_thread_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data; + struct htp_ops_context * octx = sctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + const uint32_t n = src0->ne[0]; + const uint32_t k = src1->ne[0]; + + const uint32_t ne02 = src0->ne[2]; + + const uint32_t start_job = sctx->jobs_per_thread * ith; + const uint32_t end_job = MIN(start_job + sctx->jobs_per_thread, sctx->total_jobs); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t job = start_job; job < end_job; ++job) { + const uint32_t batch = job / sctx->k_chunks; + const uint32_t chunk = job - batch * sctx->k_chunks; + + const uint32_t i03 = batch / ne02; + const uint32_t i02 = batch - i03 * ne02; + + const uint32_t col0 = chunk * sctx->col_block; + const uint32_t coln = MIN(sctx->col_block, k - col0); + + const float * A_batch = + (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]); + const float * B_batch = + (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]); + float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]); + + const bool use_hvx = (coln >= 8); + + for (uint32_t row = 0; row < n; ++row) { + const float diag = A_batch[row * n + row]; + const float inv_diag = 1.0f / diag; + + const float * A_row = A_batch + row * n; + const float * B_row = B_batch + row * k; + + if (use_hvx) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } else { + solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "solve-tri-chunk %d/%d: A=(%ux%u) B=(%ux%u) job %u:%u usec %u\n", + ith, nth, n, n, k, n, start_job, end_job, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_solve_tri(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + // left=true, lower=true, uni=false only + if (src0->ne[0] != src0->ne[1]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (src0->ne[1] != src1->ne[1]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || + dst->ne[3] != src1->ne[3]) { + return HTP_STATUS_INVAL_PARAMS; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t k = src1->ne[0]; + + const uint32_t col_block = VLEN_FP32; + const uint32_t k_chunks = (k + col_block - 1) / col_block; + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const bool batched = total_batches >= (uint32_t) octx->n_threads; + + FARF(HIGH, "solve-tri: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : batched %d\n", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], batched); + + if (batched) { + // Batch-level parallelism + const uint32_t n_threads = MIN((uint32_t) octx->n_threads, total_batches); + + struct htp_solve_tri_context sctx = { + .octx = octx, + .jobs_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_jobs = total_batches, + .k_chunks = k_chunks, + .col_block = col_block, + }; + + worker_pool_run_func(octx->ctx->worker_pool, solve_tri_batch_thread_f32, &sctx, n_threads); + } else { + // Chunk-level parallelism + const uint32_t total_jobs = total_batches * k_chunks; + const uint32_t n_threads = MIN((uint32_t) octx->n_threads, MAX(total_jobs, 1)); + + struct htp_solve_tri_context sctx = { + .octx = octx, + .jobs_per_thread = (total_jobs + n_threads - 1) / n_threads, + .total_jobs = total_jobs, + .k_chunks = k_chunks, + .col_block = col_block, + }; + + worker_pool_run_func(octx->ctx->worker_pool, solve_tri_chunk_thread_f32, &sctx, n_threads); + } + + return HTP_STATUS_OK; +} From dfb8b68799f3aa4781c11007ecfc82fc146728eb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 24 Apr 2026 11:02:00 +0300 Subject: [PATCH 187/249] ggml : minor coding style (llama/22308) --- ggml/src/ggml.c | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index eda041f4518..54d3eae3e4d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7656,7 +7656,7 @@ size_t ggml_quantize_chunk( int64_t nrows, int64_t n_per_row, const float * imatrix) { - const int64_t n = (int64_t) nrows * n_per_row; + const int64_t n = nrows * n_per_row; if (ggml_quantize_requires_imatrix(type)) { GGML_ASSERT(imatrix != NULL); @@ -7673,21 +7673,21 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { - case GGML_TYPE_Q1_0: result = quantize_q1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q1_0: result = quantize_q1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0: result = quantize_q4_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_1: result = quantize_q4_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_0: result = quantize_q5_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_1: result = quantize_q5_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_0: result = quantize_q8_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_NVFP4: result = quantize_nvfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q2_K: result = quantize_q2_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q3_K: result = quantize_q3_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_K: result = quantize_q4_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_K: result = quantize_q5_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q6_K: result = quantize_q6_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ1_0: result = quantize_tq1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ2_0: result = quantize_tq2_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; @@ -7752,9 +7752,9 @@ struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) { } bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) { - if (p0->n_threads != p1->n_threads ) return false; - if (p0->prio != p1->prio ) return false; - if (p0->poll != p1->poll ) return false; - if (p0->strict_cpu != p1->strict_cpu ) return false; + if (p0->n_threads != p1->n_threads ) return false; + if (p0->prio != p1->prio ) return false; + if (p0->poll != p1->poll ) return false; + if (p0->strict_cpu != p1->strict_cpu ) return false; return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0; } From 07d6db39e5f659a048e07e28f19e4439ebe1625e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 24 Apr 2026 13:56:03 +0300 Subject: [PATCH 188/249] metal : print GPU description (llama/22318) --- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index f17f7e2e0ce..27b78c5e6d7 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -814,7 +814,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { } // print MTL GPU family: - GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name); + GGML_LOG_INFO("%s: GPU name: %s (%s)\n", __func__, dev->props.name, dev->props.desc); // determine max supported GPU family // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf From 6576c4da90f5a8b1662697a2b73442276657677c Mon Sep 17 00:00:00 2001 From: Mengsheng Wu Date: Sat, 25 Apr 2026 00:21:33 +0800 Subject: [PATCH 189/249] hexagon: use DIRID 13 in libggml-htp.inf for modern InfVerif (llama/22306) --- ggml/src/ggml-hexagon/libggml-htp.inf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf index 360d8b1228e..39cefcdda38 100644 --- a/ggml/src/ggml-hexagon/libggml-htp.inf +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -8,7 +8,7 @@ CatalogFile = libggml-htp.cat PnpLockDown = 1 [DestinationDirs] -Drivers_Dir = 6 +Drivers_Dir = 13 [SourceDisksNames] 1 = %DiskId% From 35d679a4f8f51833e6d25b0f748632ba888d3d7b Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Fri, 24 Apr 2026 10:39:09 -0700 Subject: [PATCH 190/249] ggml-webgpu: enable FLASH_ATTN_EXT on browser without subgroup matrix (llama/22199) * ggml-webgpu: add tile flash attention fallback * ggml-webgpu: add new fields and discard usage of mnk for tile version * ggml-webgpu: modify the vec path to discard the mnk parameter * ggml-webgpu: enable flash attention vec and tile version for broswer * ggml-webgpu: stagging KV for flash attention tile version * formatting * turn on subgroup uniformity check * remove Q_TILE as it is always 1 for vec path * make row_max and exp_sum to local register * make different bindings with same underlying buffer to have the same usage flags * move path selection into the shader library and have the host consume a single flash-attn decision object. * turn off skip_validation and address buffer overlapping when nwg==1 * formatting * merge binding when kv overlap --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 326 +++++++++-------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 193 ++++++---- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 29 ++ .../wgsl-shaders/flash_attn_tile.wgsl | 330 ++++++++++++++++++ .../wgsl-shaders/flash_attn_vec_blk.wgsl | 2 +- .../wgsl-shaders/flash_attn_vec_split.wgsl | 321 ++++++++--------- 6 files changed, 809 insertions(+), 392 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 449eae808e4..e492c2123a4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -436,19 +436,27 @@ struct ggml_webgpu_unary_pipeline_key_hash { /** FlashAttention */ +enum ggml_webgpu_flash_attn_path : uint32_t { + GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 0u, + GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 1u, + GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 2u, +}; + struct ggml_webgpu_flash_attn_pipeline_key { ggml_type kv_type; uint32_t head_dim_qk; uint32_t head_dim_v; bool kv_direct; + bool kv_overlap; bool has_mask; bool has_sinks; bool uses_logit_softcap; + uint32_t path; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && - kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap; + kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && + has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && path == other.path; } }; @@ -459,39 +467,70 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.head_dim_qk); ggml_webgpu_hash_combine(seed, key.head_dim_v); ggml_webgpu_hash_combine(seed, key.kv_direct); + ggml_webgpu_hash_combine(seed, key.kv_overlap); ggml_webgpu_hash_combine(seed, key.has_mask); ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); + ggml_webgpu_hash_combine(seed, key.path); return seed; } }; struct ggml_webgpu_flash_attn_decisions { - uint32_t q_tile = 0; - uint32_t kv_tile = 0; - uint32_t wg_size = 0; + uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; + bool kv_direct = false; }; -struct ggml_webgpu_flash_attn_vec_decisions { - uint32_t kv_tile = 0; - uint32_t wg_size = 0; -}; +inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; +inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u; + +inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { + if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 || + key.head_dim_qk != key.head_dim_v) { + return 1u; + } + + switch (key.head_dim_qk) { + case 64: + case 192: + case 576: + return 2u; + case 96: + return 4u; + default: + return 1u; + } +} inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( - const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_shader_lib_context & context, + uint32_t path) { const bool has_mask = context.src3 != nullptr; const bool has_sinks = context.src4 != nullptr; - const bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && - (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + bool kv_direct = false; + if (path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH; + if (path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { + kv_direct_align = context.sg_mat_k; + } + kv_direct = (context.src1->type == GGML_TYPE_F16) && + (context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) && + (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + } ggml_webgpu_flash_attn_pipeline_key key = {}; key.kv_type = context.src1->type; key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; key.kv_direct = kv_direct; + key.kv_overlap = context.src_overlap; key.has_mask = has_mask; key.has_sinks = has_sinks; key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; + key.path = path; return key; } @@ -554,8 +593,16 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context, const ggml_webgpu_flash_attn_pipeline_key & key) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; + const size_t limit_bytes = context.wg_mem_limit_bytes; + uint32_t q_tile = context.sg_mat_m; + uint32_t kv_granularity = context.sg_mat_n; + if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + kv_granularity = std::max(1u, context.max_subgroup_size); + } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + q_tile = 1u; + kv_granularity = 8u; + } const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; size_t bytes_per_kv = 0; @@ -568,23 +615,90 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ bytes_per_kv += q_tile; bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; + return (max_kv_tile / kv_granularity) * kv_granularity; } -inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) { - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); - const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); - uint32_t kv_tile = std::max(context.sg_mat_n, std::min(32u, min_kv_tile)); - kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; +inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( + const ggml_webgpu_shader_lib_context & context, + size_t storage_offset_alignment) { + ggml_webgpu_flash_attn_decisions decisions = {}; + const size_t alignment = std::max(1u, storage_offset_alignment); + const auto * K = context.src1; + const auto * V = context.src2; + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t { + constexpr uintptr_t ptr_base_addr = 0x1000u; + const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; + return reinterpret_cast(base->data) - ptr_base_addr + tensor->view_offs; + }; + + const uint32_t k_offset_elems = + (uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); + const uint32_t v_offset_elems = + (uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); + const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) && + (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + (context.src2->type == K->type); + const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 && + V->type == GGML_TYPE_F16 && f16_vec4_aligned && + (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec; + + decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : + use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : + GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; + + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); + decisions.kv_direct = key.kv_direct; + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + decisions.q_tile = 1u; + decisions.kv_tile = std::max(8u, std::min(32u, min_kv_tile)); + decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; + decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + if (decisions.kv_direct) { + decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { + decisions.kv_tile -= 8u; + } + } + return decisions; + } + + decisions.q_tile = + decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m; + decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? + std::min(64u, ggml_webgpu_flash_attn_max_kv_tile(context, key)) : + std::min(ggml_webgpu_flash_attn_max_kv_tile(context, key), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? + GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE : + std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - if (key.kv_direct) { - kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + const uint32_t tile_kv_granularity = std::max(1u, context.max_subgroup_size); + decisions.kv_tile = + std::max(tile_kv_granularity, (decisions.kv_tile / tile_kv_granularity) * tile_kv_granularity); + } + + if (decisions.kv_direct) { + GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { + decisions.kv_tile -= decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? + std::max(1u, context.max_subgroup_size) : + context.sg_mat_n; } } - return kv_tile; + return decisions; } /** Matrix Multiplication **/ @@ -821,8 +935,6 @@ class ggml_webgpu_shader_lib { repeat_pipelines; // type std::unordered_map flash_attn_pipelines; - std::unordered_map - flash_attn_vec_pipelines; std::unordered_map @@ -2044,14 +2156,19 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); - auto it = flash_attn_pipelines.find(key); + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context, + size_t storage_offset_alignment) { + const ggml_webgpu_flash_attn_decisions decisions = + ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); + ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); + auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { return it->second; } std::vector defines; - std::string variant = "flash_attn"; + std::string variant = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC ? "flash_attn_vec" : + decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? "flash_attn_tile" : + "flash_attn"; switch (key.kv_type) { case GGML_TYPE_F32: @@ -2073,7 +2190,12 @@ class ggml_webgpu_shader_lib { if (key.has_mask) { defines.push_back("MASK"); - variant += "_mask"; + if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + defines.push_back("BLK"); + variant += "_mask_blk"; + } else { + variant += "_mask"; + } } if (key.has_sinks) { defines.push_back("SINKS"); @@ -2087,6 +2209,10 @@ class ggml_webgpu_shader_lib { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } + if (key.kv_overlap) { + defines.push_back("KV_OVERLAP"); + variant += "_kv_overlap"; + } defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); @@ -2094,129 +2220,37 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); variant += std::string("_hsv") + std::to_string(key.head_dim_v); - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - - auto decisions = std::make_shared(); - decisions->q_tile = context.sg_mat_m; - - const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); - uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - - if (key.kv_direct) { - kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; - } + const char * shader_src = wgsl_flash_attn; + if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + defines.push_back("KV_GRANULARITY=8"); + defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u"); + shader_src = wgsl_flash_attn_vec_split; + } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + shader_src = wgsl_flash_attn_tile; + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size)); + defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v))); + variant += "_tile"; + } else { + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); } - decisions->kv_tile = kv_tile; - decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - - defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); - defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); + auto pipeline_decisions = std::make_shared(decisions); + defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size)); webgpu_pipeline pipeline = - ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant); - pipeline.context = decisions; + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); + pipeline.context = pipeline_decisions; flash_attn_pipelines[key] = pipeline; return flash_attn_pipelines[key]; } - webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); - auto it = flash_attn_vec_pipelines.find(key); - if (it != flash_attn_vec_pipelines.end()) { - return it->second; - } - - std::vector defines; - std::string variant = "flash_attn_vec"; - - switch (key.kv_type) { - case GGML_TYPE_F32: - defines.push_back("KV_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("KV_F16"); - break; - case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); - break; - case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); - break; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - } - variant += std::string("_") + ggml_type_name(key.kv_type); - - if (key.has_mask) { - defines.push_back("MASK"); - defines.push_back("BLK"); - variant += "_mask_blk"; - } - if (key.has_sinks) { - defines.push_back("SINKS"); - variant += "_sinks"; - } - if (key.uses_logit_softcap) { - defines.push_back("LOGIT_SOFTCAP"); - variant += "_lgsc"; - } - if (key.kv_direct) { - defines.push_back("KV_DIRECT"); - variant += "_kvdirect"; - } - - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(key.head_dim_v); - - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - defines.push_back("Q_TILE=1"); - - auto decisions = std::make_shared(); - decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context); - decisions->wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); - uint32_t vec_ne = 1u; - - // Keep conservative defaults unless this is the f16 vec-split shape family. - if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) { - switch (key.head_dim_qk) { - case 64: - case 192: - case 576: - vec_ne = 2u; - break; - case 96: - vec_ne = 4u; - break; - default: - break; - } - } - - defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); - defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); - defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); - - webgpu_pipeline pipeline = - ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant); - pipeline.context = decisions; - flash_attn_vec_pipelines[key] = pipeline; - return flash_attn_vec_pipelines[key]; - } - - webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) { + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) { ggml_webgpu_flash_attn_blk_pipeline_key key = {}; - key.kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context); + key.kv_tile = kv_tile; auto it = flash_attn_blk_pipelines.find(key); if (it != flash_attn_blk_pipelines.end()) { return it->second; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index acc486cfdda..7ed6fdd1625 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -389,23 +389,6 @@ static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_t return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); } -static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx, - const ggml_tensor * Q, - const ggml_tensor * K, - const ggml_tensor * V) { - const size_t alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const uint32_t k_offset_elems = - (uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = - (uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); - const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - - return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); -} - static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { size_t offset = ggml_webgpu_tensor_offset(t); return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); @@ -1567,7 +1550,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, return ggml_backend_webgpu_build_multi(ctx, dispatches); } -#ifndef __EMSCRIPTEN__ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * Q, ggml_tensor * K, @@ -1585,13 +1567,29 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - const int has_mask = (mask != nullptr); - const int has_sinks = (sinks != nullptr); + const int has_mask = (mask != nullptr); + const int has_sinks = (sinks != nullptr); + const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type; + + uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + size_t kv_bind_offset = 0; + size_t kv_bind_size = 0; + if (kv_overlap) { + const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K); + const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V); + const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K); + const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V); + kv_bind_offset = std::min(k_bind_offset, v_bind_offset); + kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset; + offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type)); + offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type)); + } std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)), + offset_k, + offset_v, has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1619,10 +1617,15 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, }; std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V), }; - uint32_t binding_index = 3; + if (kv_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); + } + uint32_t binding_index = kv_overlap ? 2u : 3u; if (has_mask) { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } @@ -1638,25 +1641,25 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, shader_lib_ctx.src3 = mask; shader_lib_ctx.src4 = sinks; shader_lib_ctx.dst = dst; + shader_lib_ctx.src_overlap = kv_overlap; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; - const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V); - webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) : - ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( + shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + auto * decisions = static_cast(pipeline.context.get()); - if (!use_vec) { - auto * decisions = static_cast(pipeline.context.get()); + if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } - auto * decisions = static_cast(pipeline.context.get()); - wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; uint32_t blk_nblk0 = 0; @@ -1695,10 +1698,12 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, tmp_bind_size = tmp_size_bytes; scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); } else { - // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. + // nwg==1 writes final dst directly in vec-split; bind tmp to a tiny non-overlapping scratch region. + tmp_size_bytes = WEBGPU_STORAGE_BUF_BINDING_MULT; tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); - tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); } webgpu_pipeline blk_pipeline; @@ -1713,7 +1718,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; - blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile); blk_params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask @@ -1745,12 +1750,19 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector split_entries = { ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), ggml_webgpu_tensor_binding_size(ctx, Q)), - ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K), - ggml_webgpu_tensor_binding_size(ctx, K)), - ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), ggml_webgpu_tensor_align_offset(ctx, V), - ggml_webgpu_tensor_binding_size(ctx, V)), }; - uint32_t split_binding_index = 3; + if (kv_overlap) { + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + } else { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), + ggml_webgpu_tensor_align_offset(ctx, K), + ggml_webgpu_tensor_binding_size(ctx, K))); + split_entries.push_back(ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), + ggml_webgpu_tensor_align_offset(ctx, V), + ggml_webgpu_tensor_binding_size(ctx, V))); + } + uint32_t split_binding_index = kv_overlap ? 2u : 3u; if (has_mask) { split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), ggml_webgpu_tensor_align_offset(ctx, mask), @@ -1820,7 +1832,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, return ggml_backend_webgpu_build_multi(ctx, dispatches); } -#endif // __EMSCRIPTEN__ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; @@ -2710,11 +2721,7 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, case GGML_OP_MUL_MAT_ID: return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node); case GGML_OP_FLASH_ATTN_EXT: -#ifndef __EMSCRIPTEN__ return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); -#else - return std::nullopt; -#endif case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -3257,13 +3264,19 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.wg_mem_limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = + ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; - if (ggml_webgpu_flash_attn_use_vec(ctx->webgpu_global_ctx, Q, K, V)) { - const uint32_t kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(shader_lib_ctx); + const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( + shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + const uint32_t kv_tile = decisions.kv_tile; const uint32_t vec_nwg_cap = std::max( 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); @@ -3283,6 +3296,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const size_t tmp_size_bytes = ROUNDUP_POW2( (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); res += tmp_size_bytes + align; + } else { + res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; } if (mask != nullptr) { const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); @@ -3431,12 +3446,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->capabilities.supports_subgroups = ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); + bool valid_subgroup_matrix_config = false; #ifndef __EMSCRIPTEN__ // Accept f16 subgroup matrix configurations (square or non-square). // NVIDIA GPUs typically report square configs (e.g. 16x16x16), // while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16). // The shaders are already parameterized to handle any M/N/K dimensions. - bool valid_subgroup_matrix_config = false; if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; @@ -3450,8 +3465,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { } } } - ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; #endif + ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. @@ -3499,12 +3514,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { // Enable Dawn-specific toggles to increase native performance // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, // only for native performance? - const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", - "disable_polyfills_on_integer_div_and_mod" }; - const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + const char * const deviceEnabledToggles[] = { "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; wgpu::DawnTogglesDescriptor deviceTogglesDesc; deviceTogglesDesc.enabledToggles = deviceEnabledToggles; - deviceTogglesDesc.enabledToggleCount = 4; + deviceTogglesDesc.enabledToggleCount = 3; deviceTogglesDesc.disabledToggles = deviceDisabledToggles; deviceTogglesDesc.disabledToggleCount = 1; @@ -3782,33 +3797,63 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; case GGML_OP_FLASH_ATTN_EXT: { -#ifndef __EMSCRIPTEN__ - if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + supports_op = src0->type == GGML_TYPE_F32 && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || + src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && + src2->type == src1->type && op->type == GGML_TYPE_F32; + if (!supports_op) { break; } - // Head dimensions must be divisible by subgroup matrix dimensions - if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 || - src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.src3 = op->src[3]; + shader_lib_ctx.src4 = op->src[4]; + shader_lib_ctx.dst = const_cast(op); + shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.wg_mem_limit_bytes = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; + + const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( + shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + const size_t min_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + if (min_bytes > limit_bytes) { + supports_op = false; + } break; } - // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && - (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && - (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n, - (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct); - if (min_bytes > limit_bytes) { + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + const size_t min_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + if (min_bytes > limit_bytes) { + supports_op = false; + } break; } - supports_op = src0->type == GGML_TYPE_F32 && - (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || - src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && - src2->type == src1->type && op->type == GGML_TYPE_F32; -#endif + if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + supports_op = false; + break; + } + const size_t min_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + if (min_bytes > limit_bytes) { + supports_op = false; + } break; } case GGML_OP_RMS_NORM: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index aa2d2e54db9..6d5d69fb8de 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -138,26 +138,55 @@ struct Params { }; @group(0) @binding(0) var Q: array; +#ifdef KV_OVERLAP +@group(0) @binding(1) var K: array; +#define V K +#else @group(0) @binding(1) var K: array; @group(0) @binding(2) var V: array; +#endif #if defined(MASK) && defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; #define DST_BINDING 5 #define PARAMS_BINDING 6 +#endif #elif defined(MASK) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else @group(0) @binding(3) var mask: array; #define DST_BINDING 4 #define PARAMS_BINDING 5 +#endif #elif defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var sinks: array; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else @group(0) @binding(3) var sinks: array; #define DST_BINDING 4 #define PARAMS_BINDING 5 +#endif +#else +#ifdef KV_OVERLAP +#define DST_BINDING 2 +#define PARAMS_BINDING 3 #else #define DST_BINDING 3 #define PARAMS_BINDING 4 #endif +#endif @group(0) @binding(DST_BINDING) var dst: array>; @group(0) @binding(PARAMS_BINDING) var params: Params; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl new file mode 100644 index 00000000000..37ea23b80c8 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -0,0 +1,330 @@ +enable f16; +enable subgroups; + +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 +#define KV_STAGE_STRIDE 64 +#define Q_TILE 4 +#define KV_TILE 64 +#define WG_SIZE 128 + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + q_per_kv: u32, + + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) var Q: array; +#ifdef KV_OVERLAP +@group(0) @binding(1) var K: array>; +#define V K +#else +@group(0) @binding(1) var K: array>; +@group(0) @binding(2) var V: array>; +#endif + +#if defined(MASK) && defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else +@group(0) @binding(3) var mask: array; +@group(0) @binding(4) var sinks: array; +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#elif defined(MASK) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else +@group(0) @binding(3) var mask: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#elif defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var sinks: array; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#else +#ifdef KV_OVERLAP +#define DST_BINDING 2 +#define PARAMS_BINDING 3 +#else +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#endif +#endif + +@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(PARAMS_BINDING) var params: Params; + +const FLOAT_MIN: f32 = -1.0e9; +const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u; +const V_CHUNKS: u32 = HEAD_DIM_V / 4u; +const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; +const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; + +var q_shmem: array; +var kv_shmem: array; +var p_shmem: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + if (subgroup_size == 0u || num_subgroups < Q_TILE) { + return; + } + + let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + let batch_idx = wg_id.x / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride; + let wg_in_batch = wg_id.x % wg_per_batch; + + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_offset = v_batch_offset + k_head_idx * params.stride_v2; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head * Q_TILE; + let global_q_row = q_row_start + subgroup_id; + let row_active = subgroup_id < Q_TILE && global_q_row < params.seq_len_q; + +#ifdef MASK + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V; + + let head = f32(head_idx); + let slope = select(1.0, + select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), + pow(params.m0, head + 1.0), + head < params.n_head_log2), + params.max_bias > 0.0); + + for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let q_tile_row = elem_idx / HEAD_DIM_QK; + let q_col = elem_idx % HEAD_DIM_QK; + let head_q_row = q_row_start + q_tile_row; + let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + q_shmem[elem_idx] = f16(select( + 0.0, + Q[global_q_row_offset + q_col] * params.scale, + head_q_row < params.seq_len_q)); + } + + workgroupBarrier(); + + var row_max = FLOAT_MIN; + var exp_sum = 0.0; + var out_regs: array, OUT_REGS_PER_LANE>; + for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) { + out_regs[reg_idx] = vec4(0.0); + } + + let q_base = subgroup_id * HEAD_DIM_QK; + let subgroup_p_offset = subgroup_id * KV_TILE; + + for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); + let score_slots = min(SCORE_REGS_PER_LANE, (kv_count + subgroup_size - 1u) / subgroup_size); + let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size); + var local_scores: array; + for (var slot = 0u; slot < SCORE_REGS_PER_LANE; slot += 1u) { + local_scores[slot] = FLOAT_MIN; + } + + for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / Q_CHUNKS; + let chunk = vec_idx_local % Q_CHUNKS; + let global_k_row = kv_tile + kv_local; + let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; + let k4 = K[k_vec_index]; + let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + kv_shmem[kv_off + 0u] = k4.x; + kv_shmem[kv_off + 1u] = k4.y; + kv_shmem[kv_off + 2u] = k4.z; + kv_shmem[kv_off + 3u] = k4.w; + } + + workgroupBarrier(); + + var local_max = FLOAT_MIN; + if (row_active) { + for (var slot = 0u; slot < score_slots; slot += 1u) { + let kv_local = sg_inv_id + slot * subgroup_size; + if (kv_local >= kv_count) { + continue; + } + + let global_k_row = kv_tile + kv_local; + var dot_val = 0.0; + for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) { + let q_off = q_base + chunk * 4u; + let qv = vec4( + f32(q_shmem[q_off + 0u]), + f32(q_shmem[q_off + 1u]), + f32(q_shmem[q_off + 2u]), + f32(q_shmem[q_off + 3u])); + let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + let kv = vec4( + f32(kv_shmem[kv_off + 0u]), + f32(kv_shmem[kv_off + 1u]), + f32(kv_shmem[kv_off + 2u]), + f32(kv_shmem[kv_off + 3u])); + dot_val += dot(qv, kv); + } +#ifdef LOGIT_SOFTCAP + dot_val = params.logit_softcap * tanh(dot_val); +#endif +#ifdef MASK + let mask_idx = mask_global_offset + subgroup_id * params.seq_len_kv + global_k_row; + dot_val += slope * f32(mask[mask_idx]); +#endif + local_scores[slot] = dot_val; + local_max = max(local_max, dot_val); + } + } + + let tile_max = subgroupMax(local_max); + let new_max = max(row_max, tile_max); + let cur_exp = exp(row_max - new_max); + exp_sum *= cur_exp; + for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) { + out_regs[reg_idx] *= cur_exp; + } + + var local_sum = 0.0; + for (var slot = 0u; slot < score_slots; slot += 1u) { + let kv_local = sg_inv_id + slot * subgroup_size; + if (row_active && kv_local < kv_count) { + let p = exp(local_scores[slot] - new_max); + p_shmem[subgroup_p_offset + kv_local] = p; + local_sum += p; + } + } + + workgroupBarrier(); + + for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / V_CHUNKS; + let chunk = vec_idx_local % V_CHUNKS; + let global_v_row = kv_tile + kv_local; + let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; + let v4 = V[v_vec_index]; + let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + kv_shmem[kv_off + 0u] = v4.x; + kv_shmem[kv_off + 1u] = v4.y; + kv_shmem[kv_off + 2u] = v4.z; + kv_shmem[kv_off + 3u] = v4.w; + } + + workgroupBarrier(); + + let tile_sum = subgroupAdd(local_sum); + exp_sum += tile_sum; + row_max = new_max; + + if (row_active) { + for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) { + let chunk = sg_inv_id + reg_idx * subgroup_size; + if (chunk >= V_CHUNKS) { + continue; + } + + var acc = out_regs[reg_idx]; + for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { + let p = p_shmem[subgroup_p_offset + kv_local]; + let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + let v4 = vec4( + f32(kv_shmem[kv_off + 0u]), + f32(kv_shmem[kv_off + 1u]), + f32(kv_shmem[kv_off + 2u]), + f32(kv_shmem[kv_off + 3u])); + acc += p * v4; + } + out_regs[reg_idx] = acc; + } + } + + workgroupBarrier(); + } + +#ifdef SINKS + if (row_active) { + let sink_score = sinks[params.offset_sinks + head_idx]; + let sink_max = max(row_max, sink_score); + let sink_scale = exp(row_max - sink_max); + for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) { + out_regs[reg_idx] *= sink_scale; + } + exp_sum = exp_sum * sink_scale + exp(sink_score - sink_max); + row_max = sink_max; + } +#endif + + if (row_active) { + let inv_exp_sum = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + let row_base = dst_global_offset + subgroup_id * dst2_stride; + let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size); + for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) { + let chunk = sg_inv_id + reg_idx * subgroup_size; + if (chunk >= V_CHUNKS) { + continue; + } + let dst_vec_index = (row_base + chunk * 4u) >> 2u; + dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum; + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl index 61107c6a985..b4f7c16c35d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -15,7 +15,7 @@ struct Params { nblk1: u32, }; -@group(0) @binding(0) var mask: array; +@group(0) @binding(0) var mask: array; @group(0) @binding(1) var blk: array; @group(0) @binding(2) var params: Params; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index a52575871ae..b1e234784a8 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -1,8 +1,6 @@ -diagnostic(off, chromium.subgroup_matrix_uniformity); diagnostic(off, subgroup_uniformity); enable f16; enable subgroups; -enable chromium_experimental_subgroup_matrix; #ifdef KV_F32 #define KV_TYPE f32 @@ -13,19 +11,14 @@ enable chromium_experimental_subgroup_matrix; #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 - -#define SG_MAT_M 8 -#define SG_MAT_N 8 -#define SG_MAT_K 8 - -#define Q_TILE SG_MAT_M +#define KV_GRANULARITY 8 #define KV_TILE 16 #define WG_SIZE 64 #ifndef VEC_NE #define VEC_NE 4u #endif -#define KV_BLOCKS (KV_TILE / SG_MAT_N) +#define KV_BLOCKS (KV_TILE / KV_GRANULARITY) #define BLOCK_SIZE 32 #define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) @@ -97,6 +90,14 @@ struct Params { }; @group(0) @binding(0) var Q: array; +#ifdef KV_OVERLAP +#if defined(KV_Q4_0) || defined(KV_Q8_0) +@group(0) @binding(1) var K: array; +#else +@group(0) @binding(1) var K: array>; +#endif +#define V K +#else #if defined(KV_Q4_0) || defined(KV_Q8_0) @group(0) @binding(1) var K: array; #else @@ -107,7 +108,22 @@ struct Params { #else @group(0) @binding(2) var V: array>; #endif +#endif #if defined(MASK) && defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +@group(0) @binding(3) var sinks: array; +#ifdef BLK +#define BLK_BINDING 4 +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#else +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#else @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; #ifdef BLK @@ -120,7 +136,21 @@ struct Params { #define DST_BINDING 6 #define PARAMS_BINDING 7 #endif +#endif #elif defined(MASK) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +#ifdef BLK +#define BLK_BINDING 3 +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#else +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#else @group(0) @binding(3) var mask: array; #ifdef BLK #define BLK_BINDING 4 @@ -132,16 +162,30 @@ struct Params { #define DST_BINDING 5 #define PARAMS_BINDING 6 #endif +#endif #elif defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var sinks: array; +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else @group(0) @binding(3) var sinks: array; #define TMP_BINDING 4 #define DST_BINDING 5 #define PARAMS_BINDING 6 +#endif +#else +#ifdef KV_OVERLAP +#define TMP_BINDING 2 +#define DST_BINDING 3 +#define PARAMS_BINDING 4 #else #define TMP_BINDING 3 #define DST_BINDING 4 #define PARAMS_BINDING 5 #endif +#endif #ifdef BLK @group(0) @binding(BLK_BINDING) var blk: array; @@ -153,7 +197,7 @@ struct Params { // Just a very small float value. const FLOAT_MIN: f32 = -1.0e9; -var q_shmem: array; +var q_shmem: array; #ifndef KV_DIRECT const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); @@ -161,31 +205,27 @@ const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); var kv_shmem: array; #endif -var o_shmem: array; +var o_shmem: array; #ifdef MASK // storage for mask values -var mask_shmem: array; +var mask_shmem: array; #endif // note that we reuse the same storage for both since we only need one at a time -var inter_shmem: array; +var inter_shmem: array; // Storage for row max and exp sum during online softmax -var row_max_shmem: array; -var exp_sum_shmem: array; -var blk_state_wg: u32; - -fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { +fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { var v = select(FLOAT_MIN, - f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, + f32(inter_shmem[kv_idx]) * params.scale, kv_idx < KV_TILE); #ifdef LOGIT_SOFTCAP v = params.logit_softcap * tanh(v); #endif #ifdef MASK if (apply_mask) { - var mask_val = select(0.0,f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE); + var mask_val = select(0.0, f32(mask_shmem[kv_idx]), kv_idx < KV_TILE); v += select(mask_val, slope * mask_val, has_bias); } #endif @@ -199,19 +239,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(subgroup_size) subgroup_size: u32, @builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_invocation_id) sg_inv_id: u32) { + // Vec path processes exactly one query row per workgroup, so subgroup 0 can + // keep the running softmax state in private storage. + var row_max = FLOAT_MIN; + var exp_sum = 0.0; - // initialize row max for online softmax - for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { - row_max_shmem[i] = FLOAT_MIN; - exp_sum_shmem[i] = 0.0; - } - - for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) { + for (var i = local_id.x; i < HEAD_DIM_V; i += WG_SIZE) { o_shmem[i] = 0.0; } // workgroups per head/batch - let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_head = params.seq_len_q; let wg_per_batch = wg_per_head * params.n_heads; let dst2_stride = HEAD_DIM_V * params.n_heads; @@ -235,9 +273,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; - // starting Q row for this workgroup + // Vec path handles one Q row per workgroup. let wg_in_head = wg_in_batch % wg_per_head; - let q_row_start = wg_in_head * Q_TILE; + let q_row_start = wg_in_head; #ifdef MASK // mask offset @@ -248,21 +286,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let has_bias = params.max_bias > 0.0; let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias); - // load q tile into shared memory - for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { - let q_row = elem_idx / HEAD_DIM_QK; - let q_col = elem_idx % HEAD_DIM_QK; - let head_q_row = q_row_start + q_row; - let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + // load the single Q row into shared memory + for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) { + let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1; q_shmem[elem_idx] = f16(select( 0.0, - Q[global_q_row_offset + q_col], - head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); + Q[global_q_row_offset + elem_idx], + q_row_start < params.seq_len_q)); } for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { #ifdef BLK - let q_blk = q_row_start / Q_TILE; + let q_blk = q_row_start; let kv_blk = kv_tile / KV_TILE; let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u); let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk; @@ -270,13 +305,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #else let blk_state_local = 1u; #endif - if (local_id.x == 0u) { - blk_state_wg = blk_state_local; - } - workgroupBarrier(); - let blk_state = blk_state_wg; + let blk_state = blk_state_local; let skip_tile = blk_state == 0u; - for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) { inter_shmem[elem_idx] = f16(0.0); } @@ -360,20 +391,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let num_of_threads = subgroup_size / VEC_NE; let tx = sg_inv_id % num_of_threads; let ty = sg_inv_id / num_of_threads; - for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - continue; - } - let local_q_row_offset = q_tile_row * HEAD_DIM_QK; - + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) { let kv_idx = kv_base + ty; var partial_sum: f32 = 0.0; let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; if (kv_valid) { for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { - let q_off = local_q_row_offset + i * 4u; + let q_off = i * 4u; let qv = vec4( f32(q_shmem[q_off + 0u]), @@ -410,8 +435,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let sum_bcast = subgroupShuffle(sum, num_of_threads * ty); if (tx == 0u && kv_valid) { - let dst_idx = q_tile_row * KV_TILE + kv_idx; - inter_shmem[dst_idx] = f16(sum_bcast); + inter_shmem[kv_idx] = f16(sum_bcast); } } } @@ -422,13 +446,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let apply_mask = !skip_tile && (blk_state != 2u); if (apply_mask) { // load mask tile into shared memory for this KV block - for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { - let mask_row = elem_idx / KV_TILE; - let mask_col = elem_idx % KV_TILE; - let global_q_row = q_row_start + mask_row; - let global_k_col = kv_tile + mask_col; - let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; - let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col; + for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) { + let global_k_col = kv_tile + elem_idx; + let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv; + let mask_idx = mask_global_offset + global_k_col; mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); } } @@ -439,50 +460,40 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); // online softmax - if (!skip_tile) { - for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } - - var prev_max = row_max_shmem[q_tile_row]; - var final_max = prev_max; - // pass 1: compute final max across the full KV tile in chunks - for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { - let kv_idx = kv_offset + sg_inv_id; - let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE; - let softmax_term = select(FLOAT_MIN, - calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask), - kv_valid); - final_max = subgroupMax(max(final_max, softmax_term)); - } + if (!skip_tile && subgroup_id == 0u && q_row_start < params.seq_len_q) { + var prev_max = row_max; + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE; + let softmax_term = select(FLOAT_MIN, + calc_softmax_term(kv_idx, slope, has_bias, apply_mask), + kv_valid); + final_max = subgroupMax(max(final_max, softmax_term)); + } - var total_exp_term: f32 = 0.0; - // pass 2: compute exp sum and write P using final_max - for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { - let kv_idx = kv_offset + sg_inv_id; - let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask); - let cur_p = select(0.0, - exp(softmax_term - final_max), - kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); - total_exp_term += subgroupAdd(cur_p); - if (kv_idx < KV_TILE) { - inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p); - } + var total_exp_term: f32 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, slope, has_bias, apply_mask); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx] = f16(cur_p); } + } - let cur_exp = exp(prev_max - final_max); + let cur_exp = exp(prev_max - final_max); - if (sg_inv_id == 0) { - row_max_shmem[q_tile_row] = final_max; - exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; - } + row_max = final_max; + exp_sum = exp_sum * cur_exp + total_exp_term; - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let idx = q_tile_row * HEAD_DIM_V + elem_idx; - o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp); - } + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * cur_exp); } } @@ -562,15 +573,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); if (!skip_tile) { - // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem + // we have P (KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem // we want to compute O += P * V across the full KV tile let ne_threads : u32 = VEC_NE; let nl_threads = max(1u, subgroup_size / ne_threads); let tx_pv = sg_inv_id % nl_threads; let ty_pv = sg_inv_id / nl_threads; - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) { var lo = vec4(0.0, 0.0, 0.0, 0.0); for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) { @@ -580,7 +589,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, continue; } - let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); + let p = f32(inter_shmem[kv_idx]); #ifdef KV_DIRECT let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; let v4 = vec4(V[v_idx >> 2u]); @@ -621,11 +630,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (ty_pv == 0u) { let elem_base = vec_col * 4u; - let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base; - o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x); - o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y); - o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z); - o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w); + o_shmem[elem_base + 0u] = f16(f32(o_shmem[elem_base + 0u]) + lo_x); + o_shmem[elem_base + 1u] = f16(f32(o_shmem[elem_base + 1u]) + lo_y); + o_shmem[elem_base + 2u] = f16(f32(o_shmem[elem_base + 2u]) + lo_z); + o_shmem[elem_base + 3u] = f16(f32(o_shmem[elem_base + 3u]) + lo_w); } } } @@ -637,70 +645,46 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #ifdef SINKS // Sinks are global terms and must be applied exactly once across split workgroups. - if (iwg == 0u) { - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } - - var prev_max = row_max_shmem[q_tile_row]; - - // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum - let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); - let new_max = subgroupMax(max(prev_max, sink_val)); - let max_exp = exp(prev_max - new_max); - let sink_exp = exp(sink_val - new_max); - - let sink_exp_sum = subgroupAdd(sink_exp); - - if (sg_inv_id == 0) { - row_max_shmem[q_tile_row] = new_max; - exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; - } - - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let idx = q_tile_row * HEAD_DIM_V + elem_idx; - o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp); - } + if (iwg == 0u && subgroup_id == 0u && q_row_start < params.seq_len_q) { + var prev_max = row_max; + + // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum + let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0u); + let new_max = subgroupMax(max(prev_max, sink_val)); + let max_exp = exp(prev_max - new_max); + let sink_exp = exp(sink_val - new_max); + + let sink_exp_sum = subgroupAdd(sink_exp); + + row_max = new_max; + exp_sum = exp_sum * max_exp + sink_exp_sum; + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * max_exp); } - workgroupBarrier(); } + workgroupBarrier(); #endif let rows_per_batch = params.n_heads * params.seq_len_q; - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { break; } - + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { if (params.nwg == 1u) { - let exp_sum = exp_sum_shmem[q_tile_row]; let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); - let row_base: u32 = - params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V; + let row_base: u32 = params.offset_dst + batch_idx * dst3_stride + q_row_start * dst2_stride + + head_idx * HEAD_DIM_V; for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { - let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); - let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); - let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); - let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); - let v = vec4( - f32(o_shmem[i0]) * scale, - f32(o_shmem[i1]) * scale, - f32(o_shmem[i2]) * scale, - f32(o_shmem[i3]) * scale + f32(o_shmem[elem_base + 0u]) * scale, + f32(o_shmem[elem_base + 1u]) * scale, + f32(o_shmem[elem_base + 2u]) * scale, + f32(o_shmem[elem_base + 3u]) * scale ); let dst_vec_index: u32 = (row_base + elem_base) >> 2u; dst[dst_vec_index] = v; } } else { - let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row; + let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start; let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V; let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg; @@ -708,21 +692,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { - let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); - let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); - let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); - let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); - let tbase = tmp_row_data_base + elem_base; - tmp[tbase + 0u] = f32(o_shmem[i0]); - tmp[tbase + 1u] = f32(o_shmem[i1]); - tmp[tbase + 2u] = f32(o_shmem[i2]); - tmp[tbase + 3u] = f32(o_shmem[i3]); + tmp[tbase + 0u] = f32(o_shmem[elem_base + 0u]); + tmp[tbase + 1u] = f32(o_shmem[elem_base + 1u]); + tmp[tbase + 2u] = f32(o_shmem[elem_base + 2u]); + tmp[tbase + 3u] = f32(o_shmem[elem_base + 3u]); } if (sg_inv_id == 0u) { - tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row]; - tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row]; + tmp[tmp_row_stats_base + 0u] = exp_sum; + tmp[tmp_row_stats_base + 1u] = row_max; } } } From c546b0b1bc9a0b6dde7b330986800bf8183eda14 Mon Sep 17 00:00:00 2001 From: Trivikram Reddy <127072883+trivikram-reddy1@users.noreply.github.com> Date: Fri, 24 Apr 2026 15:55:17 -0500 Subject: [PATCH 191/249] Hexagon: Bump HMX Frequency to Max Corner (llama/22334) * hexagon: bump HMX freq to max corner * hex-mm: fix error in log msg --- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 2 +- ggml/src/ggml-hexagon/htp/main.c | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index dbca8220fab..05e3c6c2b0f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -1683,7 +1683,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", __func__, m, k, n, weight_type, + FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); // initialize eye tile (32x32 identity matrix) diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index db277a25e5a..62942f6384c 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -101,6 +101,24 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { } } + { + // Set HMX clock + HAP_power_request_t request; + memset(&request, 0, sizeof(HAP_power_request_t)); + request.type = HAP_power_set_HMX_v2; + request.hmx_v2.set_clock = TRUE; + request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; + FARF(ALWAYS, "Setting HMX clock\n"); + err = HAP_power_set((void *) &ctx, &request); + if (err != AEE_SUCCESS) { + FARF(ERROR, "Error setting HMX clock."); + return err; + } + } + return AEE_SUCCESS; } From c235b05d8a0044b771cc06d975128415810cc002 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 24 Apr 2026 23:18:15 -0700 Subject: [PATCH 192/249] ggml-webgpu: support for SSM_SCAN and disable set_rows error checking (llama/22327) * Implement ssm_scan * Remove blocking in graph_compute and check for set rows * Fix bindings * Update op support --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 72 ++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 90 +++++++++- .../ggml-webgpu/wgsl-shaders/ssm_scan.wgsl | 168 ++++++++++++++++++ 3 files changed, 328 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index e492c2123a4..16ebc32cbc7 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -98,6 +98,29 @@ struct ggml_webgpu_ssm_conv_shader_decisions { uint32_t tokens_per_wg; }; +struct ggml_webgpu_ssm_scan_pipeline_key { + int type; + int d_state; + + bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const { + return type == other.type && d_state == other.d_state; + } +}; + +struct ggml_webgpu_ssm_scan_pipeline_key_hash { + size_t operator()(const ggml_webgpu_ssm_scan_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.d_state); + return seed; + } +}; + +struct ggml_webgpu_ssm_scan_shader_decisions { + uint32_t wg_size; + uint32_t tokens_per_tile; +}; + /** Argsort **/ struct ggml_webgpu_argsort_shader_lib_context { @@ -921,6 +944,8 @@ class ggml_webgpu_shader_lib { solve_tri_pipelines; // type std::unordered_map ssm_conv_pipelines; // type/vectorized + std::unordered_map + ssm_scan_pipelines; // type/d_state std::unordered_map @@ -1433,6 +1458,53 @@ class ggml_webgpu_shader_lib { return ssm_conv_pipelines[key]; } + webgpu_pipeline get_ssm_scan_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_ssm_scan_pipeline_key key = {}; + key.type = context.dst->type; + key.d_state = (int) context.src0->ne[0]; + + auto it = ssm_scan_pipelines.find(key); + if (it != ssm_scan_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "ssm_scan"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for ssm_scan shader"); + } + + const uint32_t wg_size = (uint32_t) key.d_state; + + constexpr uint32_t tokens_per_tile = 4u; + + defines.push_back("WG_SIZE=" + std::to_string(wg_size) + "u"); + defines.push_back("TOKENS_PER_TILE=" + std::to_string(tokens_per_tile) + "u"); + + if (context.supports_subgroups) { + defines.push_back("USE_SUBGROUP_REDUCTION"); + variant += "_sg_reduce"; + } else { + variant += "_wg_reduce"; + } + + variant += "_d" + std::to_string(key.d_state); + + auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + decisions->tokens_per_tile = tokens_per_tile; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + ssm_scan_pipelines[key] = pipeline; + return ssm_scan_pipelines[key]; + } + webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_gated_delta_net_pipeline_key key = {}; key.type = context.dst->type; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 7ed6fdd1625..bcec20c1a11 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1115,6 +1115,80 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } +static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * src6, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + + webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)), + + (uint32_t) src3->ne[0], + (uint32_t) (src3->nb[1] / ggml_type_size(src3->type)), + + (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)), + + (uint32_t) (src5->nb[1] / ggml_type_size(src5->type)), + (uint32_t) (src5->nb[2] / ggml_type_size(src5->type)), + (uint32_t) (src5->nb[3] / ggml_type_size(src5->type)), + + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + (uint32_t) src4->ne[1], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + (uint32_t) ggml_nelements(src1), + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst), + }; + + const uint32_t total_wg = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]); + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + uint32_t wg_x; + uint32_t wg_y; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -2764,6 +2838,9 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, return ggml_webgpu_solve_tri(ctx, src0, src1, node); case GGML_OP_SSM_CONV: return ggml_webgpu_ssm_conv(ctx, src0, src1, node); + case GGML_OP_SSM_SCAN: + return ggml_webgpu_ssm_scan(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node->src[6], + node); case GGML_OP_GATED_DELTA_NET: return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); case GGML_OP_PAD: @@ -2822,7 +2899,10 @@ static void ggml_backend_webgpu_collect_profile_results(webgpu_context & } #endif +// Don't bother checking set_rows index overflow for now, since practically the WebGPU doesn't need to support +// models that would require it right now. static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & num_inflight_batches) { +#ifdef GGML_WEBGPU_CHECK_SET_ROWS wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, ctx->set_rows_host_error_buf.GetSize()); @@ -2835,6 +2915,10 @@ static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); } ctx->set_rows_host_error_buf.Unmap(); +#else + GGML_UNUSED(ctx); + GGML_UNUSED(num_inflight_batches); +#endif } static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { @@ -2920,8 +3004,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ggml_backend_webgpu_check_set_rows(ctx, num_inflight_batches); } - ggml_backend_webgpu_wait_queue(ctx->global_ctx); - WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -3941,6 +4023,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SSM_CONV: supports_op = op->type == GGML_TYPE_F32; break; + case GGML_OP_SSM_SCAN: + supports_op = op->type == GGML_TYPE_F32 && + src0->ne[0] <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + break; case GGML_OP_GATED_DELTA_NET: { const uint32_t s_v = (uint32_t) src2->ne[0]; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl new file mode 100644 index 00000000000..64324738591 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl @@ -0,0 +1,168 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif + +struct Params { + offset_s: u32, + offset_x: u32, + offset_dt: u32, + offset_A: u32, + offset_B: u32, + offset_C: u32, + offset_ids: u32, + offset_dst: u32, + + stride_s1: u32, + stride_s2: u32, + stride_s3: u32, + + stride_x1: u32, + stride_x2: u32, + stride_x3: u32, + + stride_dt1: u32, + stride_dt2: u32, + + a_ne0: u32, + stride_A1: u32, + + stride_B1: u32, + stride_B2: u32, + stride_B3: u32, + + stride_C1: u32, + stride_C2: u32, + stride_C3: u32, + + d_state: u32, + d_inner: u32, + n_head: u32, + n_group: u32, + n_seq_tokens: u32, + n_seqs: u32, + + y_elems: u32, +}; + +@group(0) @binding(0) var s_in: array; +@group(0) @binding(1) var x: array; +@group(0) @binding(2) var dt: array; +@group(0) @binding(3) var A: array; +@group(0) @binding(4) var B: array; +@group(0) @binding(5) var C: array; +@group(0) @binding(6) var ids: array; +@group(0) @binding(7) var dst: array; +@group(0) @binding(8) var params: Params; + +var shared_x_dt: array; +var shared_dtsp: array; +var shared_reduce: array; + +fn reduce_base(token_in_tile: u32) -> u32 { + return token_in_tile * WG_SIZE; +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32 +#endif +) { + let tid = local_id.x; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + let i1 = wg_linear % params.d_inner; + let head_seq = wg_linear / params.d_inner; + let ir = head_seq % params.n_head; + let i3 = head_seq / params.n_head; + + let state_slot = u32(ids[params.offset_ids + i3]); + let g = ir / (params.n_head / params.n_group); + + let s_idx = params.offset_s + tid + i1 * params.stride_s1 + ir * params.stride_s2 + state_slot * params.stride_s3; + var s_prev = s_in[s_idx]; + + let A0 = A[params.offset_A + (tid % params.a_ne0) + ir * params.stride_A1]; + + for (var token_base = 0u; token_base < params.n_seq_tokens; token_base += TOKENS_PER_TILE) { + if (tid < TOKENS_PER_TILE) { + let token = token_base + tid; + if (token < params.n_seq_tokens) { + let x_idx = params.offset_x + i1 + ir * params.stride_x1 + token * params.stride_x2 + i3 * params.stride_x3; + let dt_idx = params.offset_dt + ir + token * params.stride_dt1 + i3 * params.stride_dt2; + let dt0 = dt[dt_idx]; + let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0); + shared_dtsp[tid] = dtsp; + shared_x_dt[tid] = x[x_idx] * dtsp; + } + } + + workgroupBarrier(); + + for (var token_in_tile = 0u; token_in_tile < TOKENS_PER_TILE; token_in_tile++) { + let token = token_base + token_in_tile; + if (token >= params.n_seq_tokens) { + break; + } + + let x_dt = shared_x_dt[token_in_tile]; + let dA = exp(shared_dtsp[token_in_tile] * A0); + let reduce_idx = reduce_base(token_in_tile) + tid; + + let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3; + let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3; + let s = s_prev * dA + B[b_idx] * x_dt; + s_prev = s; + +#ifdef USE_SUBGROUP_REDUCTION + let subgroup_partial = subgroupAdd(s * C[c_idx]); + if (subgroup_invocation_id == 0u) { + shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial; + } +#else + shared_reduce[reduce_idx] = s * C[c_idx]; +#endif + + workgroupBarrier(); + +#ifdef USE_SUBGROUP_REDUCTION + if (tid == 0u) { + var sum = 0.0; + for (var sg = 0u; sg < num_subgroups; sg++) { + sum += shared_reduce[reduce_base(token_in_tile) + sg]; + } + let y_idx = + params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) + + i3 * (params.n_seq_tokens * params.n_head * params.d_inner); + dst[y_idx] = sum; + } +#else + for (var stride = WG_SIZE / 2u; stride > 0u; stride >>= 1u) { + if (tid < stride) { + shared_reduce[reduce_idx] += shared_reduce[reduce_idx + stride]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + let y_idx = + params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) + + i3 * (params.n_seq_tokens * params.n_head * params.d_inner); + dst[y_idx] = shared_reduce[reduce_base(token_in_tile)]; + } +#endif + + workgroupBarrier(); + } + } + + let state_idx = + params.offset_dst + params.y_elems + tid + i1 * params.d_state + ir * (params.d_state * params.d_inner) + + i3 * (params.d_state * params.d_inner * params.n_head); + dst[state_idx] = s_prev; +} From 6296fd5a904edbd9785a9e8e06d38564e3c70b49 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Sat, 25 Apr 2026 14:20:14 +0800 Subject: [PATCH 193/249] Optimize Q4_0 mul_mat for Arc770, add scripts (llama/22291) * opt arc770 for Q4_0 * add for Q4_0 * update the script * add help script for windows * update guide * fix format issue * convert from dos to unix for format issue * fix missed -sm parameter --- ggml/src/ggml-sycl/common.hpp | 2 +- ggml/src/ggml-sycl/ggml-sycl.cpp | 10 ++++- ggml/src/ggml-sycl/sycl_hw.cpp | 72 +++++++++++++++++++++++++++----- ggml/src/ggml-sycl/sycl_hw.hpp | 24 ++++++++--- 4 files changed, 90 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 0101b27640a..5abf2290651 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -224,7 +224,7 @@ struct sycl_device_info { // cudaOccupancyMaxActiveBlocksPerMultiprocessor bool vmm; // virtual memory support size_t total_vram; - //sycl_hw_info hw_info; \\ device id and aarch, currently not used + sycl_hw_info hw_info; optimize_feature opt_feature; }; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 36923160d72..1eead625e76 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -104,6 +104,7 @@ static ggml_sycl_device_info ggml_sycl_init() { info.max_work_group_sizes[i] = prop.get_max_work_group_size(); info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units(); + info.devices[i].hw_info = get_device_hw_info(&device); } @@ -3703,9 +3704,16 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization // is enabled takes precedence over DMMV, the current if-else implementation // requires disabling DMMV if both conditions are met + if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) { - use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + // Arc770 get benefit with Q4_0 by skipping it. + if (!(ggml_sycl_info().devices[ctx.device].hw_info.arch == + gpu_arch::intel_gpu_acm_g10 && + src0->type == GGML_TYPE_Q4_0)) { + use_dequantize_mul_mat_vec = + use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + } } if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { diff --git a/ggml/src/ggml-sycl/sycl_hw.cpp b/ggml/src/ggml-sycl/sycl_hw.cpp index 7041140034b..03b0c37a3cd 100644 --- a/ggml/src/ggml-sycl/sycl_hw.cpp +++ b/ggml/src/ggml-sycl/sycl_hw.cpp @@ -1,15 +1,67 @@ #include "sycl_hw.hpp" -// TODO: currently not used -/* -sycl_hw_info get_device_hw_info(sycl::device *device_ptr) { - sycl_hw_info res; - int32_t id = device_ptr->get_info(); - res.device_id = id; +using namespace std; - syclex::architecture arch = device_ptr->get_info(); - res.arch = arch; +/*defined in +* /opt/intel/oneapi/compiler/latest/include/sycl/ext/oneapi/experimental/device_architecture.def +*/ +static map> arch2name = { + {gpu_arch::intel_gpu_bdw, {"intel_gpu_bdw", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_skl, {"intel_gpu_skl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_kbl, {"intel_gpu_kbl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_cfl, {"intel_gpu_cfl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_apl, {"intel_gpu_apl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_glk, {"intel_gpu_glk", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_whl, {"intel_gpu_whl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_aml, {"intel_gpu_aml", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_cml, {"intel_gpu_cml", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_icllp, {"intel_gpu_icllp", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_ehl, {"intel_gpu_ehl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_tgllp, {"intel_gpu_tgllp", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_rkl, {"intel_gpu_rkl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_adl_s, {"intel_gpu_adl_s", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_adl_p, {"intel_gpu_adl_p", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_adl_n, {"intel_gpu_adl_n", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_dg1, {"intel_gpu_dg1", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_acm_g10, {"intel_gpu_acm_g10", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_acm_g11, {"intel_gpu_acm_g11", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_acm_g12, {"intel_gpu_acm_g12", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_pvc, {"intel_gpu_pvc", GPU_FAMILY_DGPU_CLOUD}}, + {gpu_arch::intel_gpu_pvc_vg, {"intel_gpu_pvc_vg", GPU_FAMILY_DGPU_CLOUD}}, + {gpu_arch::intel_gpu_mtl_u, {"intel_gpu_mtl_u", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_mtl_h, {"intel_gpu_mtl_h", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_arl_h, {"intel_gpu_arl_h", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_bmg_g21, {"intel_gpu_bmg_g21", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_bmg_g31, {"intel_gpu_bmg_g31", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_lnl_m, {"intel_gpu_lnl_m", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_ptl_h, {"intel_gpu_ptl_h", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_ptl_u, {"intel_gpu_ptl_u", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_wcl, {"intel_gpu_wcl", GPU_FAMILY_IGPU_XE}} +}; + + +sycl_hw_info get_device_hw_info(sycl::device* device_ptr) { + sycl_hw_info res; + int32_t id = + device_ptr->get_info(); + res.device_id = id; + + res.name = device_ptr->get_info(); - return res; + syclex::architecture arch = + device_ptr->get_info(); + res.arch = arch; + + map>::iterator it = + arch2name.find(res.arch); + if (it != arch2name.end()) { + res.arch_name = it->second.first; + res.gpu_family = it->second.second; + } else { + res.arch_name = "unknown"; + res.gpu_family = GPU_FAMILY_UKNOWN; + } + + return res; } -*/ diff --git a/ggml/src/ggml-sycl/sycl_hw.hpp b/ggml/src/ggml-sycl/sycl_hw.hpp index 36b140bf037..a5d20462572 100644 --- a/ggml/src/ggml-sycl/sycl_hw.hpp +++ b/ggml/src/ggml-sycl/sycl_hw.hpp @@ -9,18 +9,30 @@ #include namespace syclex = sycl::ext::oneapi::experimental; +using gpu_arch = sycl::ext::oneapi::experimental::architecture; + +// It's used to mark the GPU computing capacity +// The value must flow the order of performance. +enum sycl_intel_gpu_family { + GPU_FAMILY_UKNOWN = -1, + // iGPU without Xe core, before Meteor Lake iGPU(Xe) + GPU_FAMILY_IGPU_NON_XE = 0, + // iGPU with Xe core, Meteor Lake iGPU or newer. + GPU_FAMILY_IGPU_XE = 1, + // dGPU for gaming in client/data center (DG1/FLex 140 or newer). + GPU_FAMILY_DGPU_CLIENT_GAME = 2, + // dGPU for AI in cloud, PVC or newer. + GPU_FAMILY_DGPU_CLOUD = 3 +}; -// TODO: currently not used -/* struct sycl_hw_info { syclex::architecture arch; + const char* arch_name; int32_t device_id; + std::string name; + sycl_intel_gpu_family gpu_family; }; -bool is_in_vector(std::vector &vec, int item); - sycl_hw_info get_device_hw_info(sycl::device *device_ptr); -*/ - #endif // SYCL_HW_HPP From 21da84303e9cea074f16850e9a2573f68b75b48f Mon Sep 17 00:00:00 2001 From: Developer-Ecosystem-Engineering <65677710+Developer-Ecosystem-Engineering@users.noreply.github.com> Date: Sat, 25 Apr 2026 05:14:28 -0700 Subject: [PATCH 194/249] metal : optimize Metal Tensor API usage for GGML_OP_MUL_MAT (llama/20962) * Optimize Metal Tensor API usage for matmul2d Separates the Metal Tensor API (matmul2d) path in kernel_mul_mm into its own standalone kernel, gated by GGML_METAL_HAS_TENSOR. The legacy simdgroup_matrix kernel is preserved under #else. Previously both paths were interleaved via #ifdef blocks within a single kernel, forcing the tensor path to share the legacy kernel's data layout and threadgroup memory scheme. Splitting the kernel enabled memory and dispatch optimizations that weren't possible when the two paths shared code structure. * cont : cleanup * cont : cleanup * cont : cleanup --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal-device.cpp | 26 ++- ggml/src/ggml-metal/ggml-metal-device.h | 2 + ggml/src/ggml-metal/ggml-metal-device.m | 17 +- ggml/src/ggml-metal/ggml-metal-impl.h | 13 ++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 7 +- ggml/src/ggml-metal/ggml-metal.metal | 233 +++++++++++++--------- 6 files changed, 189 insertions(+), 109 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 07d016d2227..d211bf79f14 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -677,7 +677,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta const ggml_type tsrc1 = op->src[1]->type; const bool bc_inp = op->src[0]->ne[0] % 32 != 0; - const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0; + + constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y; + constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; + + const bool has_tensor = ggml_metal_device_get_props(ggml_metal_library_get_device(lib))->has_tensor; + + const bool bc_out = has_tensor + ? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0) + : (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0); snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); @@ -694,8 +702,20 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ggml_metal_cv_free(cv); } - // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes - res.smem = bc_out ? 8192 : 4096 + 2048; + if (has_tensor) { + res.nr0 = NRA; + res.nr1 = NRB; + + const size_t smem_a = NRA * N_MM_NK_TOTAL * sizeof(ggml_fp16_t); + res.smem = smem_a; + } else { + res.nr0 = 64; + res.nr1 = 32; + + res.smem = bc_out ? 8192 : (4096 + 2048); + } + + res.nsg = N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y; return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index b423501358e..a6c1dab5515 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -102,6 +102,8 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev void ggml_metal_library_free(ggml_metal_library_t lib); +ggml_metal_device_t ggml_metal_library_get_device(ggml_metal_library_t lib); + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 27b78c5e6d7..fe90aafe7bc 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -95,8 +95,8 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi struct ggml_metal_library { id obj; - id device; + ggml_metal_device_t dev; ggml_metal_pipelines_t pipelines; // cache of compiled pipelines NSLock * lock; @@ -251,7 +251,7 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); res->obj = library; - res->device = device; + res->dev = dev; res->pipelines = ggml_metal_pipelines_init(); res->lock = [NSLock new]; @@ -318,7 +318,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev } res->obj = library; - res->device = device; + res->dev = dev; res->pipelines = ggml_metal_pipelines_init(); res->lock = [NSLock new]; @@ -341,6 +341,10 @@ void ggml_metal_library_free(ggml_metal_library_t lib) { free(lib); } +ggml_metal_device_t ggml_metal_library_get_device(ggml_metal_library_t lib) { + return lib->dev; +} + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { [lib->lock lock]; @@ -405,7 +409,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_ return res; } - id obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; + id device = ggml_metal_device_get_obj(lib->dev); + id obj = [device newComputePipelineStateWithFunction:mtl_function error:&error]; [mtl_function release]; @@ -699,7 +704,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { " auto sB = tB.slice(0, 0); \n" " mm.run(sB, sA, cT); \n" " \n" - " auto tC = tensor, tensor_inline>(C, dextents(4, 4)); \n" + " auto tC = tensor, tensor_inline>(C, dextents(16, 16)); \n" " \n" " cT.store(tC); \n" "}"; @@ -749,7 +754,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { " auto sB = tB.slice(0, 0); \n" " mm.run(sB, sA, cT); \n" " \n" - " auto tC = tensor, tensor_inline>(C, dextents(4, 4)); \n" + " auto tC = tensor, tensor_inline>(C, dextents(16, 16)); \n" " \n" " cT.store(tC); \n" "}"; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 379a8b33a14..ff74cafb5b7 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -1,6 +1,19 @@ #ifndef GGML_METAL_IMPL #define GGML_METAL_IMPL +// kernel parameters for mat-mat threadgroups +// +// TODO: become function constants + +#define SZ_SIMDGROUP 16 +#define N_MM_NK 2 +#define N_MM_NK_TOTAL (SZ_SIMDGROUP * N_MM_NK) + +#define N_MM_BLOCK_X 4 +#define N_MM_BLOCK_Y 2 +#define N_MM_SIMD_GROUP_X 2 +#define N_MM_SIMD_GROUP_Y 2 + // kernel parameters for mat-vec threadgroups // // N_R0: number of src0 rows to process per simdgroup diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e173527909a..5fa162c875c 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2195,7 +2195,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { const size_t smem = pipeline.smem; ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); + + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; + + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + nr1 - 1) / nr1), ((ne01 + nr0 - 1) / nr0), ne12 * ne13, 32, nsg, 1); } else { auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9f38c9d2968..c372eaedeae 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -9306,7 +9306,137 @@ constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; // each block_q contains 16*nl weights -template +#ifdef GGML_METAL_HAS_TENSOR +template< + typename SA, typename SA_4x4, typename SA_8x8, + typename SB, typename SB_2x4, typename SB_8x8, + typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &), + typename T0, typename T0_4x4, typename T1, typename T1_2x4> +kernel void kernel_mul_mm( + constant ggml_metal_kargs_mul_mm & args, + device const char * srcA, + device const char * srcB, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + + // Matrix dimensions: A(M,K) x B(K,N) -> C(M,N) + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne1; + + // Batch dimension handling + const int im = tgpig.z; + const int i12 = im % args.ne12; + const int i13 = im / args.ne12; + + // Batch offsets for srcA and srcB + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + // Tile dimensions + constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; + constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y; + + // Tile offsets in output matrix + const int ra = tgpig.y * NRA; + const int rb = tgpig.x * NRB; + + // Threadgroup memory for dequantized A tile only + threadgroup SA * sa = (threadgroup SA *)(shmem); + + // Work-item count for A loading + constexpr int A_WORK_ITEMS = NRA * N_MM_NK; + constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y; + + // tA wraps threadgroup memory + auto tA = tensor(sa, dextents(N_MM_NK_TOTAL, NRA)); + + // tB wraps device memory directly + device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13); + const int strideB = args.nb11 / sizeof(T1); + auto tB = tensor(ptrB, dextents(K, N), array({1, strideB})); + + // Configure matmul operation + mpp::tensor_ops::matmul2d< + mpp::tensor_ops::matmul2d_descriptor( + NRB, NRA, N_MM_NK_TOTAL, false, true, true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups> mm; + + auto cT = mm.get_destination_cooperative_tensor(); + + // Accumulate partial results over K dimension + for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) { + // === PHASE 1: Dequantization of A into threadgroup memory === + for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) { + const int row = work / N_MM_NK; + const int k_chunk = work % N_MM_NK; + const int k_pos = loop_k + k_chunk * 16; + const short k_base = k_chunk * 16; + + // Bounds check: skip device read if row is out of matrix bounds + if (ra + row < M) { + if (is_same::value && FC_mul_mm_bc_inp) { + // Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4). + // MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd, + // nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned. + // Mirrors the legacy kernel's existing guard. + device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0); + + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0; + } + } else { + const int block_idx = k_pos / (16 * nl); + const short il = (k_pos / 16) % nl; + + device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0); + + SA_4x4 temp_a; + dequantize_func(row_ptr + block_idx, il, temp_a); + + FOR_UNROLL (short i = 0; i < 16; i++) { + // Zero-pad A for K positions beyond valid range (handles partial K iterations) + sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0; + } + } + } else { + // Zero-pad rows beyond matrix bounds + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // === PHASE 2: Tensor matmul === + auto mA = tA.slice(0, 0); + auto mB = tB.slice(loop_k, rb); + + mm.run(mB, mA, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Store result tile to output matrix (with batch offset) + // cT.store handles bounds checking via tD's extents (M, N) + device float * dstBatch = (device float *)dst + im * N * M; + + auto tD = tensor(dstBatch, dextents(M, N), array({1, M})); + cT.store(tD.slice(ra, rb)); +} + +#else + +template< + typename S0, typename S0_4x4, typename S0_8x8, + typename S1, typename S1_2x4, typename S1_8x8, + typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), + typename T0, typename T0_4x4, typename T1, typename T1_2x4> kernel void kernel_mul_mm( constant ggml_metal_kargs_mul_mm & args, device const char * src0, @@ -9320,10 +9450,6 @@ kernel void kernel_mul_mm( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); -#ifdef GGML_METAL_HAS_TENSOR - threadgroup float * sc = (threadgroup float *)(shmem); -#endif - constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -9363,7 +9489,6 @@ kernel void kernel_mul_mm( + args.nb11*(r1 + lr1) + args.nb10*iy); -#ifndef GGML_METAL_HAS_TENSOR S0_8x8 ma[4]; S1_8x8 mb[2]; @@ -9372,19 +9497,8 @@ kernel void kernel_mul_mm( for (short i = 0; i < 8; i++){ mc[i] = make_filled_simdgroup_matrix(0.f); } -#else - auto tA = tensor, tensor_inline>(sa, dextents(NK, NR0)); - auto tB = tensor, tensor_inline>(sb, dextents(NR1, NK )); - - mpp::tensor_ops::matmul2d< - mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), - execution_simdgroups<4>> mm; - - auto cT = mm.get_destination_cooperative_tensor(); -#endif for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { -#ifndef GGML_METAL_HAS_TENSOR // load data and store to threadgroup memory if (is_same::value && FC_mul_mm_bc_inp) { threadgroup_barrier(mem_flags::mem_threadgroup); @@ -9454,66 +9568,6 @@ kernel void kernel_mul_mm( *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y)); } -#else - // load data and store to threadgroup memory - if (is_same::value && FC_mul_mm_bc_inp) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // no need for dequantization - for (short i = 0; i < 16; i++) { - const short sx = 2*il0 + i/8; - const short sy = (tiitg/NL0)/8; - - const short lx = i%8; - const short ly = (tiitg/NL0)%8; - //const short lx = (tiitg/NL0)%8; - //const short ly = i%8; - - *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; - } - } else { - S0_4x4 temp_a; - dequantize_func(x, il, temp_a); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - FOR_UNROLL (short i = 0; i < 16; i++) { - const short sx = 2*il0 + i/8; - const short sy = (tiitg/NL0)/8; - - const short lx = i%8; - const short ly = (tiitg/NL0)%8; - //const short lx = (tiitg/NL0)%8; - //const short ly = i%8; - - *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; - } - } - - if (FC_mul_mm_bc_inp) { - for (short i = 0; i < 8; ++i) { - const short sx = (tiitg%NL1); - const short sy = (tiitg/NL1)/8; - - const short lx = i; - const short ly = (tiitg/NL1)%8; - //const short lx = (tiitg/NL1)%8; - //const short ly = i; - - *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; - } - } else { - const short sx = (tiitg%NL1); - const short sy = (tiitg/NL1)/8; - - //const short lx = i; - const short ly = (tiitg/NL1)%8; - //const short lx = (tiitg/NL1)%8; - //const short ly = i; - - *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y)); - } -#endif il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; @@ -9522,7 +9576,6 @@ kernel void kernel_mul_mm( threadgroup_barrier(mem_flags::mem_threadgroup); -#ifndef GGML_METAL_HAS_TENSOR // load matrices from threadgroup memory and conduct outer products threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); @@ -9549,24 +9602,10 @@ kernel void kernel_mul_mm( lsma += 8*64; lsmb += 4*64; } -#else - auto sA = tA.slice(0, 0); - auto sB = tB.slice(0, 0); - - mm.run(sB, sA, cT); -#endif } if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) { // if no bounds checks on the output are needed, we can directly write to device memory -#ifdef GGML_METAL_HAS_TENSOR - device float * C = (device float *) dst + - r0 + \ - r1 * args.ne0 + im*args.ne1*args.ne0; - - auto tC = tensor, tensor_inline>(C, dextents(args.ne0, NR1)); - cT.store(tC); -#else device float * C = (device float *) dst + (r0 + 32*(sgitg & 1)) + \ (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; @@ -9574,21 +9613,15 @@ kernel void kernel_mul_mm( for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false); } -#endif } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0; -#ifdef GGML_METAL_HAS_TENSOR - auto tC = tensor, tensor_inline>(sc, dextents(NR0, NR1)); - cT.store(tC); -#else for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false); } -#endif threadgroup_barrier(mem_flags::mem_threadgroup); @@ -9614,6 +9647,8 @@ kernel void kernel_mul_mm( } } +#endif // GGML_METAL_HAS_TENSOR + template // n_expert_used kernel void kernel_mul_mm_id_map0( constant ggml_metal_kargs_mul_mm_id_map0 & args, @@ -9789,7 +9824,7 @@ kernel void kernel_mul_mm_id( const short ib = 8*sx + sy; - *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; + *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0; } } else { S0_4x4 temp_a; From da738a74f56248a3488bf9f54dfd2da67abe1196 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 25 Apr 2026 14:15:03 +0200 Subject: [PATCH 195/249] CUDA: reduce MMQ stream-k overhead (llama/22298) * CUDA: reduce MMQ stream-k overhead * use 32 bit integers for kbc --- ggml/src/ggml-cuda/mmq.cuh | 277 ++++++++++++++++++------------------- 1 file changed, 138 insertions(+), 139 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index b1a319de9be..91a1b737a82 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3478,10 +3478,10 @@ template static __global__ void mul_mat_q( const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, - const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, - const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const int ncols_max) { + const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, + const uint3 channel_ratio, const uint3 nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const uint3 sample_ratio, const uint3 nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 ntx) { // Skip unused template specializations for faster compilation: if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { @@ -3495,8 +3495,7 @@ static __global__ void mul_mat_q( constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); - const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x - const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y + const uint32_t nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y // Initialize the ids for writing back data with just the index. // For regular matrix multiplications this is never changed. @@ -3517,8 +3516,9 @@ static __global__ void mul_mat_q( // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: #if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA { - const int wt = blockIdx.z / nchannels_y; - const int zt = blockIdx.z - wt*nchannels_y; + const uint2 tmp2 = fast_div_modulo(blockIdx.z, nchannels_y); + const int wt = tmp2.x; + const int zt = tmp2.y; const int jt = blockIdx.y; const int it = blockIdx.x; @@ -3561,40 +3561,40 @@ static __global__ void mul_mat_q( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, - tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); + tile_x_max_i, tile_y_max_j, 0, blocks_per_ne00.z); return; } #endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA - constexpr int ITER_K = get_iter_k(type); - - const int64_t blocks_per_ne00 = ncols_x / qk; - constexpr int blocks_per_iter = ITER_K / qk; + constexpr int ITER_K = get_iter_k(type); + constexpr int blocks_per_iter = ITER_K / qk; // kbc == k block continuous, current index in continuous ijk space. - int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int kbc = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; + int kbc_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; - kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; - kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; + kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter; + kbc_stop -= fastmodulo(kbc_stop, blocks_per_ne00) % blocks_per_iter; // kb0 == k index when doing the matrix multiplication for an output tile. - int kb0_start = kbc % blocks_per_ne00; - int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc); - while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) { - int tmp = kbc; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; + int kb0_start = fastmodulo(kbc, blocks_per_ne00); + int kb0_stop = min(blocks_per_ne00.z, uint32_t(kb0_start + kbc_stop - kbc)); + while (kbc < kbc_stop && kb0_stop == int(blocks_per_ne00.z)) { + int tmp = fastdiv(kbc, blocks_per_ne00); + uint2 tmp2 = fast_div_modulo(tmp, ntx); + const int jt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nchannels_y); + const int zt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nsamples_y); + const int wt = tmp2.y; + const int it = tmp2.x; // Defaults for regular matrix multiplication: int col_low = 0; @@ -3612,11 +3612,11 @@ static __global__ void mul_mat_q( offset_dst = 0; if (jt*mmq_x >= col_diff) { - kbc += blocks_per_ne00; - kbc -= kbc % blocks_per_ne00; + kbc += blocks_per_ne00.z; + kbc -= fastmodulo(kbc, blocks_per_ne00); kb0_start = 0; - kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc)); continue; } @@ -3641,32 +3641,34 @@ static __global__ void mul_mat_q( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); - kbc += blocks_per_ne00; - kbc -= kbc % blocks_per_ne00; + kbc += blocks_per_ne00.z; + kbc -= fastmodulo(kbc, blocks_per_ne00); kb0_start = 0; - kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc)); } if (kbc >= kbc_stop) { return; } - int tmp = kbc; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; + int tmp = fastdiv(kbc, blocks_per_ne00); + uint2 tmp2 = fast_div_modulo(tmp, ntx); + const int jt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nchannels_y); + const int zt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nsamples_y); + const int wt = tmp2.y; + const int it = tmp2.x; // Defaults for regular matrix multiplication: int col_low = 0; @@ -3708,7 +3710,7 @@ static __global__ void mul_mat_q( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. mul_mat_q_process_tile @@ -3717,46 +3719,37 @@ static __global__ void mul_mat_q( } template -static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, - const int32_t * expert_bounds, - float * __restrict__ dst, - const float * __restrict__ tmp_last_tile, - const int ncols_x, - const int nrows_x, - const int ncols_dst, - const size_t stride_col_dst, - const int nchannels_y, - const size_t stride_channel_dst, - const int nsamples_y, - const size_t stride_sample_dst, - const int ncols_max) { - constexpr int mmq_y = get_mmq_y_device(); - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int ITER_K = get_iter_k(type); - - constexpr int blocks_per_iter = ITER_K / qk; - const int64_t blocks_per_ne00 = ncols_x / qk; +__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device()/2, 1) +static __global__ void mul_mat_q_stream_k_fixup( + const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, + float * __restrict__ tmp_last_tile, const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, + const int stride_col_dst, const uint3 nchannels_y, const int stride_channel_dst, const uint3 nsamples_y, + const int stride_sample_dst, const uint3 ntx) { + constexpr int mmq_y = get_mmq_y_device(); + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int ITER_K = get_iter_k(type); + constexpr int blocks_per_iter = ITER_K / qk; - constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int nwarps = mmq_get_nwarps_device()/2; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + float sum[mmq_x / nwarps] = {0.0f}; + const int i = blockIdx.y*warp_size + threadIdx.x; - const int ntx = (ncols_max + mmq_x - 1) / mmq_x; - const int nty = (nrows_x + mmq_y - 1) / mmq_y; + const int nty = (nrows_x + mmq_y - 1) / mmq_y; const int bidx0 = blockIdx.x; // kbc == k block continuous, current index in continuous ijk space. - int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int kbc0 = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; + int kbc0_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; - kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter; - kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter; + kbc0 -= fastmodulo(kbc0, blocks_per_ne00) % blocks_per_iter; + kbc0_stop -= fastmodulo(kbc0_stop, blocks_per_ne00) % blocks_per_iter; const bool did_not_have_any_data = kbc0 == kbc0_stop; - const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0; - const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0; + const bool wrote_beginning_of_tile = fastmodulo(kbc0, blocks_per_ne00) == 0; + const bool did_not_write_last = fastdiv(kbc0, blocks_per_ne00) == fastdiv(kbc0_stop, blocks_per_ne00) && fastmodulo(kbc0_stop, blocks_per_ne00) != 0; if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { return; } @@ -3765,11 +3758,11 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, // Iterate over previous blocks and sum up partial sums written to fixup buffer. // All CUDA blocks that get here must have a previous block that needs a fixup. - int64_t bidx = bidx0 - 1; - int64_t kbc_stop = kbc0; + int bidx = bidx0 - 1; + int kbc_stop = kbc0; while(true) { - int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + int kbc = int64_t(bidx)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; + kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter; if (kbc == kbc_stop) { // Did not have any data. bidx--; @@ -3779,20 +3772,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, any_fixup = true; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; - } + sum[j0/nwarps] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; } // If this block started in a previous tile we are done and don't need to combine additional partial results. - if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) { + if (fastmodulo(kbc, blocks_per_ne00) == 0 || fastdiv(kbc, blocks_per_ne00) < fastdiv(kbc0, blocks_per_ne00)) { break; } bidx--; @@ -3803,14 +3792,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, return; } - int tmp = kbc0; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; + int tmp = fastdiv(kbc0, blocks_per_ne00); + uint2 tmp2 = fast_div_modulo(tmp, ntx); + const int jt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nchannels_y); + const int zt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nsamples_y); + const int wt = tmp2.y; + const int it = tmp2.x; if (!ids_dst) { const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y; @@ -3818,6 +3809,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, const int i_max = nrows_x - it*mmq_y - 1; const int j_max = ncols_dst - jt*mmq_x - 1; + if (need_check && i > i_max) { + return; + } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -3827,16 +3821,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, return; } -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - if (need_check && i > i_max) { - continue; - } - - dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; - } + dst[j*stride_col_dst + i] += sum[j0/nwarps]; } return; } @@ -3856,6 +3841,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, const int i_max = nrows_x - it*mmq_y - 1; const int j_max = col_diff - jt*mmq_x - 1; + if (need_check && i > i_max) { + return; + } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -3865,16 +3853,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, return; } -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - if (need_check && i > i_max) { - continue; - } - - dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; - } + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[j0/nwarps]; } } @@ -3922,29 +3901,44 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int channel_ratio = args.nchannels_y / args.nchannels_x; const int sample_ratio = args.nsamples_y / args.nsamples_x; + const uint3 blocks_per_ne00_fd = init_fastdiv_values(args.ncols_x / ggml_cuda_type_traits::qk); + const uint3 ntx_fd = init_fastdiv_values(ntx); + const uint3 nchannels_y_fd = init_fastdiv_values(args.nchannels_y); + const uint3 nsamples_y_fd = init_fastdiv_values(args.nsamples_y); + const uint3 channel_ratio_fd = init_fastdiv_values(channel_ratio); + const uint3 sample_ratio_fd = init_fastdiv_values(sample_ratio); + if (!args.use_stream_k) { if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); } else { constexpr bool need_check = true; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); } return; } - const dim3 block_nums_stream_k(nsm, 1, 1); - const bool fixup_needed = ntx*nty*ntzw % nsm != 0; + // For the stream-k kernel it is possible to run it with tiling by setting the number of CUDA blocks equal to the number of tiles. + // This is worthwhile if the efficiency of tiling is high and skipping the fixup kernel is more important. + const int ntiles_dst = ntx * nty * ntzw; + const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm; + const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves); + const dim3 block_nums_stream_k(GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm, 1, 1); + + GGML_ASSERT(ntiles_dst * blocks_per_ne00_fd.z < (1 << 30)); // Assert that variable kbc will not overflow. + + const bool fixup_needed = ntiles_dst % block_nums_stream_k.x != 0; ggml_cuda_pool & pool = ctx.pool(id); ggml_cuda_pool_alloc tmp_fixup(pool); @@ -3952,40 +3946,45 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y); } + const dim3 block_nums_fixup(block_nums_stream_k.x, mmq_y/warp_size, 1); + const dim3 block_dims_fixup(block_dims.x, block_dims.y/2, block_dims.z); + if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); if (!fixup_needed) { return; } - mul_mat_q_stream_k_fixup<<>> - (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, - args.ncols_max); + CUDA_CHECK(cudaGetLastError()); + mul_mat_q_stream_k_fixup<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, + args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst, + ntx_fd); } else { constexpr bool need_check = true; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); if (!fixup_needed) { return; } - mul_mat_q_stream_k_fixup<<>> - (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, - args.ncols_max); + CUDA_CHECK(cudaGetLastError()); + mul_mat_q_stream_k_fixup<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, + args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst, + ntx_fd); } } From 1be2adf7b3df28f450a60822ad3952316aaa6644 Mon Sep 17 00:00:00 2001 From: Trivikram Reddy <127072883+trivikram-reddy1@users.noreply.github.com> Date: Sat, 25 Apr 2026 19:58:26 -0500 Subject: [PATCH 196/249] hexagon: guard HMX clock request for v75+ platforms (llama/22377) --- ggml/src/ggml-hexagon/htp/main.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 62942f6384c..f58347304be 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -101,6 +101,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { } } +#if __HVX_ARCH__ >= 75 { // Set HMX clock HAP_power_request_t request; @@ -118,6 +119,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { return err; } } +#endif return AEE_SUCCESS; } From 93a3f376421cd1439e2f25e2b8687bb5685e4e15 Mon Sep 17 00:00:00 2001 From: lhez Date: Sat, 25 Apr 2026 21:21:58 -0700 Subject: [PATCH 197/249] opencl: add iq4_nl support (llama/22272) * opencl: add general support for iq4_nl * opencl: add iq4_nl gemm/gemv for adreno * opencl: pack 2 lut entries into a uint --- ggml/src/ggml-opencl/CMakeLists.txt | 5 + ggml/src/ggml-opencl/ggml-opencl.cpp | 594 ++++++++++++++++++ ggml/src/ggml-opencl/kernels/cvt.cl | 107 ++++ .../kernels/gemm_noshuffle_iq4_nl_f32.cl | 150 +++++ .../kernels/gemv_noshuffle_iq4_nl_f32.cl | 302 +++++++++ .../kernels/mul_mm_iq4_nl_f32_l4_lm.cl | 171 +++++ .../ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl | 164 +++++ .../kernels/mul_mv_iq4_nl_f32_flat.cl | 202 ++++++ 8 files changed, 1695 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 772fc537494..5ed83eeb48a 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -96,6 +96,8 @@ set(GGML_OPENCL_KERNELS mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 mul_mv_q8_0_f32_flat + mul_mv_iq4_nl_f32 + mul_mv_iq4_nl_f32_flat mul_mv_mxfp4_f32 mul_mv_mxfp4_f32_flat mul_mv_id_q4_0_f32_8x_flat @@ -110,12 +112,15 @@ set(GGML_OPENCL_KERNELS mul_mm_q4_0_f32_l4_lm mul_mm_q4_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_iq4_nl_f32_l4_lm mul_mm_q4_k_f32_l4_lm mul_mm_q5_k_f32_l4_lm mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 gemv_noshuffle_q4_1_f32 gemm_noshuffle_q4_1_f32 + gemv_noshuffle_iq4_nl_f32 + gemm_noshuffle_iq4_nl_f32 gemv_noshuffle_general_q8_0_f32 gemv_noshuffle_q4_k_f32 gemm_noshuffle_q4_k_f32 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 8bc7ae65a6d..4d31591a4a6 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -545,6 +545,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q5_K_noshuffle; cl_kernel kernel_restore_block_q5_K_noshuffle; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; + cl_kernel kernel_convert_block_iq4_nl, kernel_restore_block_iq4_nl; + cl_kernel kernel_convert_block_iq4_nl_noshuffle; + cl_kernel kernel_restore_block_iq4_nl_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; cl_kernel kernel_mul_mv_q4_1_f32_flat; @@ -556,6 +559,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat; + cl_kernel kernel_mul_mv_iq4_nl_f32; + cl_kernel kernel_mul_mv_iq4_nl_f32_flat; cl_kernel kernel_solve_tri_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; cl_kernel kernel_argsort_f32_i32; @@ -594,6 +599,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_q4_k_f32_l4_lm; cl_kernel kernel_mul_mm_q5_k_f32_l4_lm; cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; + cl_kernel kernel_mul_mm_iq4_nl_f32_l4_lm; std::vector profiling_info; @@ -734,6 +740,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemm_noshuffle_q6_K_f32; cl_kernel kernel_gemv_noshuffle_q5_k_f32; cl_kernel kernel_gemm_noshuffle_q5_k_f32; + cl_kernel kernel_gemv_noshuffle_iq4_nl_f32; + cl_kernel kernel_gemm_noshuffle_iq4_nl_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { @@ -954,6 +962,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl_noshuffle", &err), err)); GGML_LOG_CONT("."); } @@ -1359,6 +1371,40 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_iq4_nl_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_iq4_nl_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_iq4_nl_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_iq4_nl_f32 = clCreateKernel(prog, "kernel_mul_mv_iq4_nl_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_iq4_nl_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_iq4_nl_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_iq4_nl_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_iq4_nl_f32_flat = clCreateKernel(prog, "kernel_mul_mv_iq4_nl_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_mxfp4_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1567,6 +1613,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_iq4_nl_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_iq4_nl_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_iq4_nl_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_iq4_nl_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_iq4_nl_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_q4_k_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2647,6 +2710,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemm_noshuffle_iq4_nl_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_iq4_nl_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_iq4_nl_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_iq4_nl_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_iq4_nl_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_iq4_nl_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_iq4_nl_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_iq4_nl_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_q8_0_f32_8x4 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3597,6 +3699,30 @@ struct ggml_tensor_extra_cl_q8_0 { } }; +struct ggml_tensor_extra_cl_iq4_nl { + cl_mem q = nullptr; + cl_mem q_img = nullptr; + + cl_mem d = nullptr; + cl_mem d_img = nullptr; + + size_t size_q = 0; + size_t size_d = 0; + + ~ggml_tensor_extra_cl_iq4_nl() { + reset(); + } + + void reset() { + if (q != nullptr) { CL_CHECK(clReleaseMemObject(q)); q = nullptr; } + if (d != nullptr) { CL_CHECK(clReleaseMemObject(d)); d = nullptr; } + q_img = nullptr; + d_img = nullptr; + size_q = 0; + size_d = 0; + } +}; + struct ggml_tensor_extra_cl_q4_K { // Quantized values cl_mem q = nullptr; @@ -4097,6 +4223,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_IQ4_NL || op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q5_K || op->src[0]->type == GGML_TYPE_Q6_K) { @@ -4295,6 +4422,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl) { + delete e; + } + for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl_in_use) { + delete e; + } for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K) { delete e; } @@ -4390,6 +4523,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_iq4_nl * ggml_opencl_alloc_temp_tensor_extra_iq4_nl() { + ggml_tensor_extra_cl_iq4_nl * extra; + if (temp_tensor_extras_iq4_nl.empty()) { + extra = new ggml_tensor_extra_cl_iq4_nl(); + } else { + extra = temp_tensor_extras_iq4_nl.back(); + temp_tensor_extras_iq4_nl.pop_back(); + } + + temp_tensor_extras_iq4_nl_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_q4_K * ggml_opencl_alloc_temp_tensor_extra_q4_K() { ggml_tensor_extra_cl_q4_K * extra; if (temp_tensor_extras_q4_K.empty()) { @@ -4461,6 +4609,11 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q8_0_in_use.clear(); + for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl_in_use) { + temp_tensor_extras_iq4_nl.push_back(e); + } + temp_tensor_extras_iq4_nl_in_use.clear(); + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) { temp_tensor_extras_q4_K.push_back(e); } @@ -4492,6 +4645,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; std::vector temp_tensor_extras_q8_0_in_use; + std::vector temp_tensor_extras_iq4_nl; + std::vector temp_tensor_extras_iq4_nl_in_use; std::vector temp_tensor_extras_q4_K; std::vector temp_tensor_extras_q4_K_in_use; std::vector temp_tensor_extras_q5_K; @@ -5123,6 +5278,87 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } + if (tensor->type == GGML_TYPE_IQ4_NL) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tensors in OpenCL backend should have been allocated and initialized"); + + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_iq4_nl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_iq4_nl(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)/2); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_iq4_nl; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_iq4_nl_noshuffle; + } + #else + cl_kernel kernel = backend_ctx->kernel_convert_block_iq4_nl; + #endif + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64)*64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + // Transpose q as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + } +#endif + return; + } if (tensor->type == GGML_TYPE_Q4_K) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -5775,6 +6011,78 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + if (tensor->type == GGML_TYPE_IQ4_NL) { + ggml_tensor_extra_cl_iq4_nl * extra = (ggml_tensor_extra_cl_iq4_nl *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*(ggml_blck_size(tensor->type)/2); + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // transpose q, d back + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + cl_kernel kernel = backend_ctx->kernel_restore_block_iq4_nl_noshuffle; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } +#endif + cl_kernel kernel = backend_ctx->kernel_restore_block_iq4_nl; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (tensor->type == GGML_TYPE_Q4_K) { ggml_tensor_extra_cl_q4_K * extra = (ggml_tensor_extra_cl_q4_K *)tensor->extra; @@ -9840,6 +10148,178 @@ static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_t #endif } +static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % 32 == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_iq4_nl->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); @@ -10634,6 +11114,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; @@ -10738,6 +11219,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // iq4_nl x fp32 + if (src0t == GGML_TYPE_IQ4_NL && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_iq4_nl_f32_adreno(backend, src0, src1, dst); + return; + } + // q8_0 x fp32 if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && enable_adreno_trans_weight(backend_ctx, src0)) { @@ -11302,6 +11789,48 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_IQ4_NL: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_iq4_nl_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q4_K: { if (ne11 < 32) { break; @@ -11829,6 +12358,70 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } + case GGML_TYPE_IQ4_NL: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_iq4_nl_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 8; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 8; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#else + kernel = backend_ctx->kernel_mul_mv_iq4_nl_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); #endif // GGML_OPENCL_SOA_Q break; } @@ -12131,6 +12724,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_Q2_K) { // Each SIMD group produces N_DST values in the result. Assuming each // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 39af32d282b..f3937d8304c 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -87,6 +87,17 @@ struct block_q6_K { half d; // super-block scale }; +//------------------------------------------------------------------------------ +// block_iq4_nl +//------------------------------------------------------------------------------ +#define QK4_NL 32 + +struct block_iq4_nl +{ + half d; + uint8_t qs[QK4_NL / 2]; +}; + //------------------------------------------------------------------------------ // kernel_convert_block_q4_0 // Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). @@ -895,3 +906,99 @@ kernel void kernel_restore_block_q6_K_noshuffle( b->scales[i] = s[i]; } } + +//------------------------------------------------------------------------------ +// kernel_convert_block_iq4_nl +// Convert the block_iq4_nl format to 2 separate arrays (AOS -> SOA). +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_iq4_nl( + global struct block_iq4_nl * src0, + global uchar * dst_q, + global half * dst_d, + uchar mask_0F, + uchar mask_F0, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK4_NL/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_iq4_nl( + global uchar * src_q, + global half * src_d, + global struct block_iq4_nl * dst, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + + for (int i = 0; i < QK4_NL/2; ++i) { + b->qs[i] = q[i]; + } +} + +kernel void kernel_convert_block_iq4_nl_noshuffle( + global struct block_iq4_nl * src0, + global uchar * dst_q, + global half * dst_d, + uchar mask_0F, + uchar mask_F0, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + for (int i = 0; i < QK4_NL/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + q[i + QK4_NL/4] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } +} + +kernel void kernel_restore_block_iq4_nl_noshuffle( + global uchar * src_q, + global half * src_d, + global struct block_iq4_nl * dst, + uchar mask_0F, + uchar mask_F0, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK4_NL/4; ++i) { + uchar x0 = q[i + 0 ]; + uchar x1 = q[i + QK4_NL/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl new file mode 100644 index 00000000000..6869d822862 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl @@ -0,0 +1,150 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +constant half kvalues_iq4nl[16] = { + (half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f, + (half) -49.f, (half) -35.f, (half)-22.f, (half)-10.f, + (half) 1.f, (half) 13.f, (half) 25.f, (half) 38.f, + (half) 53.f, (half) 69.f, (half) 89.f, (half)113.f +}; + +// Packed LUT: 2 FP16 values per uint, 8 unique constant loads instead of 16 +constant uint iq4nl_packed[8] = { + 0xD680D7F0u, // idx 0,1: -127, -104 + 0xD410D530u, // idx 2,3: -83, -65 + 0xD060D220u, // idx 4,5: -49, -35 + 0xC900CD80u, // idx 6,7: -22, -10 + 0x4A803C00u, // idx 8,9: 1, 13 + 0x50C04E40u, // idx 10,11: 25, 38 + 0x545052A0u, // idx 12,13: 53, 69 + 0x57105590u // idx 14,15: 89, 113 +}; + +// Packed dequant: 1 uint constant load (8-way divergence) + shift + as_half +#define IQ4_NL_DEQUANT(nibble) as_half((ushort)(iq4nl_packed[(nibble) >> 1] >> (((nibble) & 1u) << 4))) + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_iq4_nl_f32( + global const ushort * src0_q, + global const half * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding +) { + dst = (global float *)((global char *)dst + offsetd); + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * weight_ptr = src0_q + gx_2; + global const half * scale_ptr = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1); + + ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m)); + + half4 scale = vload4(0, scale_ptr + (i/32)*(m)); + + // j=0 + dequantized_weights.s0 = IQ4_NL_DEQUANT(bits4.s0 & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT(bits4.s1 & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT(bits4.s2 & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT(bits4.s3 & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1); + dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 4) & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 4) & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 4) & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 4) & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1); + dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 8) & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 8) & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 8) & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 8) & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1); + dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 12) & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 12) & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 12) & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 12) & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl new file mode 100644 index 00000000000..9386bf25a6f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl @@ -0,0 +1,302 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK4_NL 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +constant half kvalues_iq4nl[16] = { + (half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f, + (half) -49.f, (half) -35.f, (half)-22.f, (half)-10.f, + (half) 1.f, (half) 13.f, (half) 25.f, (half) 38.f, + (half) 53.f, (half) 69.f, (half) 89.f, (half)113.f +}; + +// Packed LUT: 2 FP16 values per uint, 8 unique constant loads instead of 16 +constant uint iq4nl_packed[8] = { + 0xD680D7F0u, // idx 0,1: -127, -104 + 0xD410D530u, // idx 2,3: -83, -65 + 0xD060D220u, // idx 4,5: -49, -35 + 0xC900CD80u, // idx 6,7: -22, -10 + 0x4A803C00u, // idx 8,9: 1, 13 + 0x50C04E40u, // idx 10,11: 25, 38 + 0x545052A0u, // idx 12,13: 53, 69 + 0x57105590u // idx 14,15: 89, 113 +}; + +// Packed dequant: 1 uint constant load (8-way divergence) + shift + as_half +#define IQ4_NL_DEQUANT(nibble) as_half((ushort)(iq4nl_packed[(nibble) >> 1] >> (((nibble) & 1u) << 4))) + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_iq4_nl_f32( + read_only image1d_buffer_t src0_q, + global half2 * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + private uint4 regA; + private half2 regS; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_NL); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl new file mode 100644 index 00000000000..11ff7f8d9dc --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl @@ -0,0 +1,171 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +constant float kvalues_iq4nl[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, + 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +kernel void kernel_mul_mm_iq4_nl_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + // IQ4_NL: use lookup table instead of linear (nibble - 8) + float4 v1 = (float4)(kvalues_iq4nl[(q.s0 )&0x0F], kvalues_iq4nl[(q.s1 )&0x0F], + kvalues_iq4nl[(q.s2 )&0x0F], kvalues_iq4nl[(q.s3 )&0x0F])*d; + float4 v2 = (float4)(kvalues_iq4nl[(q.s0>>4)&0x0F], kvalues_iq4nl[(q.s1>>4)&0x0F], + kvalues_iq4nl[(q.s2>>4)&0x0F], kvalues_iq4nl[(q.s3>>4)&0x0F])*d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl new file mode 100644 index 00000000000..a6a325cd729 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl @@ -0,0 +1,164 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_NL 32 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +constant float kvalues_iq4nl[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, + 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +//------------------------------------------------------------------------------ +// block_iq4_nl +//------------------------------------------------------------------------------ +struct block_iq4_nl +{ + half d; + uint8_t qs[QK4_NL / 2]; +}; + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32 +//------------------------------------------------------------------------------ +// Compute inner product between half a block of iq4_nl and 16 floats (yl). +// il indicates where the quants begin (0 or 8). +inline float block_iq4_nl_dot_y( + global struct block_iq4_nl * qb_curr, + private float * yl, + int il +) { + float d = qb_curr->d; + float acc = 0.f; + global uchar * qs = qb_curr->qs + il; + for (int i = 0; i < 8; ++i) { + acc += yl[i] * kvalues_iq4nl[qs[i] & 0x0F]; + acc += yl[i+8] * kvalues_iq4nl[qs[i] >> 4]; + } + return d * acc; +} + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup group works on 4 rows +#define N_SUBGROUP 1 // number of subgroups in a thread group +#define N_SUBGROUP_SIZE 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SUBGROUP 1 +#define N_SUBGROUP_SIZE 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + + const ulong nb = ne00/QK4_NL; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SUBGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_iq4_nl * x = (global struct block_iq4_nl *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[N_DST]={0.f}; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_NL + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SUBGROUP_SIZE/2) { + for (int i = 0; i < 8; ++i) { + yl[i] = yb[i]; + yl[i+8] = yb[i+16]; + } + + for (int row = 0; row < N_DST; row++) { + sumf[row] += block_iq4_nl_dot_y(x+ib+row*nb, yl, il); + } + + yb += QK4_NL * (N_SUBGROUP_SIZE/2); + } + + float tot[N_DST] = { + sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), + sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; + for (int row = 0; row < N_DST; ++row) { + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_iq4_nl_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl new file mode 100644 index 00000000000..8c5b3f52e42 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl @@ -0,0 +1,202 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_NL 32 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +constant float kvalues_iq4nl[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, + 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +//------------------------------------------------------------------------------ +// block_iq4_nl +//------------------------------------------------------------------------------ +struct block_iq4_nl +{ + half d; + uint8_t qs[QK4_NL / 2]; +}; + +// Compute dot product between half a block of iq4_nl quants and activations. +// x points to the quant bytes, dh points to the scale. +// yl has 16 activation values: [0..7] for low nibbles, [8..15] for high nibbles. +// il indicates offset into the quant bytes (0 or 8). +inline float block_iq4_nl_dot_y_flat( + global uchar * x, + global half * dh, + private float * yl, + int il +) { + float d = *dh; + global uchar * qs = x + il; + float acc = 0.f; + for (int i = 0; i < 8; ++i) { + acc += yl[i] * kvalues_iq4nl[qs[i] & 0x0F]; + acc += yl[i+8] * kvalues_iq4nl[qs[i] >> 4]; + } + return d * acc; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each subgroup works on 8 rows +#define N_SUBGROUP 1 // number of subgroups in a thread group +#define N_SUBGROUP_SIZE 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SUBGROUP 1 +#define N_SUBGROUP_SIZE 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_NL; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SUBGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_NL/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_NL/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_NL + il; + + for (int ib = ix; ib < nb; ib += N_SUBGROUP_SIZE/2) { + for (int i = 0; i < 8; ++i) { + yl[i] = yb[i]; + yl[i+8] = yb[i+16]; + } + + sumf.s0 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 0*nb*QK4_NL/2, d + ib + 0*nb, yl, il); + sumf.s1 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 1*nb*QK4_NL/2, d + ib + 1*nb, yl, il); + sumf.s2 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 2*nb*QK4_NL/2, d + ib + 2*nb, yl, il); + sumf.s3 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 3*nb*QK4_NL/2, d + ib + 3*nb, yl, il); + + sumf.s4 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 4*nb*QK4_NL/2, d + ib + 4*nb, yl, il); + sumf.s5 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 5*nb*QK4_NL/2, d + ib + 5*nb, yl, il); + sumf.s6 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 6*nb*QK4_NL/2, d + ib + 6*nb, yl, il); + sumf.s7 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 7*nb*QK4_NL/2, d + ib + 7*nb, yl, il); + + yb += QK4_NL * (N_SUBGROUP_SIZE/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_iq4_nl_f32_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} From 4e11277a198de0f0ccc9a6fbfd6e943a7602b546 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sun, 26 Apr 2026 06:27:50 +0000 Subject: [PATCH 198/249] ggml-cpu: optimize avx2 q6_k (llama/22345) --- ggml/src/ggml-cpu/arch/x86/quants.c | 46 ++++++++++++----------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 0a3e071e57c..94b19b82bbc 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -2300,9 +2300,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #if defined __AVX2__ - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i m15 = _mm256_set1_epi8(15); __m256 acc = _mm256_setzero_ps(); @@ -2314,53 +2313,45 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const uint8_t * GGML_RESTRICT qh = x[i].qh; const int8_t * GGML_RESTRICT q8 = y[i].qs; + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m256i scales_16 = _mm256_cvtepi8_epi16(scales); + const __m256i q8sclsub = _mm256_slli_epi32(_mm256_madd_epi16(q8sums, scales_16), 5); __m256i sumi = _mm256_setzero_si256(); int is = 0; for (int j = 0; j < QK_K/128; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); - is += 4; - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m3), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(12)), 2); + const __m256i q4h_2 = _mm256_and_si256(q4bitsH, _mm256_set1_epi8(48)); + const __m256i q4h_3 = _mm256_srli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(-64)), 2); - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); - const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m15), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m15), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m15), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m15), q4h_3); const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); @@ -2372,6 +2363,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } + sumi = _mm256_sub_epi32(sumi, q8sclsub); acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); } From 2f3df42cddca762047c2884342b683549420be71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 26 Apr 2026 08:28:14 +0200 Subject: [PATCH 199/249] ggml-cpu : re-enable fast gelu_quick_f16 (llama/22339) --- ggml/src/ggml-cpu/vec.h | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index a0375a28de0..bcd68da9aa9 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -1036,12 +1036,12 @@ inline static float ggml_gelu_quick_f32(float x) { return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); } -//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { -// const uint16_t * i16 = (const uint16_t *) x; -// for (int i = 0; i < n; ++i) { -// y[i] = ggml_table_gelu_quick_f16[i16[i]]; -// } -//} +inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = ggml_table_gelu_quick_f16[i16[i]]; + } +} #ifdef GGML_GELU_QUICK_FP16 inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { @@ -1060,13 +1060,6 @@ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * } #endif -inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { - for (int i = 0; i < n; ++i) { - float v = GGML_CPU_FP16_TO_FP32(x[i]); - y[i] = GGML_CPU_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v)))); - } -} - // Sigmoid Linear Unit (SiLU) function inline static float ggml_silu_f32(float x) { return x/(1.0f + expf(-x)); From 9bf6c3c8602b976b79139d908fd63ffe048749ba Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Sun, 26 Apr 2026 09:21:45 +0200 Subject: [PATCH 200/249] CUDA: better coalesce data-access for contiguous concat (llama/22330) Also, distribute all elements across CTAs evenly instead of launching one CTA per dim --- ggml/src/ggml-cuda/concat.cu | 141 +++++++++++++++-------------------- 1 file changed, 62 insertions(+), 79 deletions(-) diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index e9ffd274b99..102f944f924 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -1,96 +1,79 @@ #include "concat.cuh" // contiguous kernels -static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { - return; - } - - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - - if (nidx < ne00) { // src0 - int offset_src = - nidx + - blockIdx.y * ne00 + - blockIdx.z * ne00 * gridDim.y; - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - (nidx - ne00) + - blockIdx.y * (ne0 - ne00) + - blockIdx.z * (ne0 - ne00) * gridDim.y; - dst[offset_dst] = y[offset_src]; - } -} - -static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) { - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { - return; - } +template +static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont(const float * x, + const float * y, + float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + static_assert(dim >= 0 && dim <= 2, "dim must be in [0, 2]"); + + const int64_t n = ne0 * ne1 * ne2; + + for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) { + if constexpr (dim == 0) { + const int64_t row = i / ne0; + const int64_t i0 = i - row * ne0; + + if (i0 < ne00) { + dst[i] = x[row * ne00 + i0]; + } else { + dst[i] = y[row * (ne0 - ne00) + (i0 - ne00)]; + } + } else if constexpr (dim == 1) { + const int64_t dst_plane = ne0 * ne1; + const int64_t src0_plane = ne0 * ne01; + const int64_t src1_plane = dst_plane - src0_plane; + const int64_t i2 = i / dst_plane; + const int64_t i01 = i - i2 * dst_plane; + + if (i01 < src0_plane) { + dst[i] = x[i2 * src0_plane + i01]; + } else { + dst[i] = y[i2 * src1_plane + (i01 - src0_plane)]; + } + } else { + const int64_t src0_size = ne0 * ne1 * ne02; - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - - if (blockIdx.y < (unsigned)ne01) { // src0 - int offset_src = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * ne01; - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - nidx + - (blockIdx.y - ne01) * ne0 + - blockIdx.z * ne0 * (gridDim.y - ne01); - dst[offset_dst] = y[offset_src]; + if (i < src0_size) { + dst[i] = x[i]; + } else { + dst[i] = y[i - src0_size]; + } + } } } -static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) { - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { - return; - } - - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - - if (blockIdx.z < (unsigned)ne02) { // src0 - int offset_src = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - nidx + - blockIdx.y * ne0 + - (blockIdx.z - ne02) * ne0 * gridDim.y; - dst[offset_dst] = y[offset_src]; - } -} +static void concat_f32_cuda(const float * x, + const float * y, + float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int dim, + cudaStream_t stream) { + const int64_t n = ne0 * ne1 * ne2; + const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; -static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) { - int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; - dim3 gridDim(num_blocks, ne1, ne2); if (dim == 0) { - concat_f32_dim0<<>>(x, y, dst, ne0, ne00); + concat_f32_cont<0> + <<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } if (dim == 1) { - concat_f32_dim1<<>>(x, y, dst, ne0, ne01); + concat_f32_cont<1> + <<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } - concat_f32_dim2<<>>(x, y, dst, ne0, ne02); + concat_f32_cont<2><<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); } // non-contiguous kernel (slow) From 7296b9c7faec4df1e683d0ef652c3ed4c79ac6ff Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sun, 26 Apr 2026 17:04:40 +0530 Subject: [PATCH 201/249] Fix recurrent state serialization for partial reads and writes (llama/22362) The previous code worked only for full tensor reads and writes and was hitting `GGML_ASSERT(size == ggml_nbytes(tensor)); ` assert when tested with llama-server. --- ggml/src/ggml-backend-meta.cpp | 66 +++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 6d22f3421b1..41a61775bd6 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1205,40 +1205,57 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg if (split_state.n_segments != 1) { GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); - GGML_ASSERT(offset == 0); - GGML_ASSERT(size == ggml_nbytes(tensor)); GGML_ASSERT(tensor->ne[3] == 1); + size_t offset_data = 0; std::vector simple_offsets(n_bufs, 0); if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { GGML_ASSERT(tensor->ne[2] == 1); + + const size_t row_stride = tensor->nb[1]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t r_start = offset / row_stride; + const int64_t r_count = size / row_stride; + GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + const int64_t blck_size = ggml_blck_size(tensor->type); for (size_t s = 0; s < split_state.n_segments; s++) { for (size_t j = 0; j < n_bufs; j++) { ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; - ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes, - tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]); + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, + r_count, simple_tensor->nb[1], tensor->nb[1]); offset_data += nbytes; simple_offsets[j] += nbytes; } } - GGML_ASSERT(offset_data*tensor->ne[1] == size); + GGML_ASSERT(offset_data*r_count == size); return; } GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + + const size_t row_stride = tensor->nb[2]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t r_start = offset / row_stride; + const int64_t r_count = size / row_stride; + GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + for (size_t s = 0; s < split_state.n_segments; s++) { for (size_t j = 0; j < n_bufs; j++) { ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; - ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes, - tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]); + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, + r_count, simple_tensor->nb[2], tensor->nb[2]); offset_data += nbytes; simple_offsets[j] += nbytes; } } - GGML_ASSERT(offset_data*tensor->ne[2] == size); + GGML_ASSERT(offset_data*r_count == size); return; } @@ -1295,40 +1312,57 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co if (split_state.n_segments != 1) { GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); - GGML_ASSERT(offset == 0); - GGML_ASSERT(size == ggml_nbytes(tensor)); GGML_ASSERT(tensor->ne[3] == 1); + size_t offset_data = 0; std::vector simple_offsets(n_bufs, 0); if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { GGML_ASSERT(tensor->ne[2] == 1); + + const size_t row_stride = tensor->nb[1]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t r_start = offset / row_stride; + const int64_t r_count = size / row_stride; + GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + const int64_t blck_size = ggml_blck_size(tensor->type); for (size_t s = 0; s < split_state.n_segments; s++) { for (size_t j = 0; j < n_bufs; j++) { const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; - ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes, - tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]); + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, + r_count, simple_tensor->nb[1], tensor->nb[1]); offset_data += nbytes; simple_offsets[j] += nbytes; } } - GGML_ASSERT(offset_data*tensor->ne[1] == size); + GGML_ASSERT(offset_data*r_count == size); return; } GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + + const size_t row_stride = tensor->nb[2]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t r_start = offset / row_stride; + const int64_t r_count = size / row_stride; + GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + for (size_t s = 0; s < split_state.n_segments; s++) { for (size_t j = 0; j < n_bufs; j++) { const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; - ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes, - tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]); + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, + r_count, simple_tensor->nb[2], tensor->nb[2]); offset_data += nbytes; simple_offsets[j] += nbytes; } } - GGML_ASSERT(offset_data*tensor->ne[2] == size); + GGML_ASSERT(offset_data*r_count == size); return; } From 1478450e61487ae2cd44916d902ccd626539de47 Mon Sep 17 00:00:00 2001 From: Rithik Sharma Date: Sun, 26 Apr 2026 09:26:28 -0700 Subject: [PATCH 202/249] add performance-portable tuning for register-tile and subgroup matmul (llama/22241) --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 36 ++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 16ebc32cbc7..503171ee14f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -26,20 +26,23 @@ // Matrix multiplication parameters // Register tiling parameters -#define WEBGPU_MUL_MAT_TILE_M 8 -#define WEBGPU_MUL_MAT_TILE_N 8 +#define WEBGPU_MUL_MAT_TILE_M 4 +#define WEBGPU_MUL_MAT_TILE_N 4 #define WEBGPU_MUL_MAT_WG_SIZE_M 8 #define WEBGPU_MUL_MAT_WG_SIZE_N 8 -#define WEBGPU_MUL_MAT_TILE_K 32 +#define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8 +#define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32 // Subgroup matrix parameters // The number of subgroups in the M dimension #define WEBGPU_MUL_MAT_SUBGROUP_M 2 // The number of subgroups in the N dimension -#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +#define WEBGPU_MUL_MAT_SUBGROUP_N 4 // The number of subgroup matrices each subgroup accumulates over #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 +#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32 +#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32 // Matrix-vector multiplication parameters #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 @@ -1734,13 +1737,24 @@ class ggml_webgpu_shader_lib { // VEC/SCALAR controls defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + const bool is_quant = ggml_is_quantized(context.src0->type); + + uint32_t tile_k; + if (key.use_subgroup_matrix) { + tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT + : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT; + } else { + tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT + : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + } + // Tiles defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); - defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); // Subgroup matrix specifics if (key.use_subgroup_matrix) { + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u"); defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u"); @@ -1760,12 +1774,13 @@ class ggml_webgpu_shader_lib { if (!key.use_subgroup_matrix) { defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); } auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); - decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; + decisions->tile_k = tile_k; decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; decisions->use_subgroup_matrix = key.use_subgroup_matrix; @@ -1962,10 +1977,15 @@ class ggml_webgpu_shader_lib { defines.push_back("SCALAR"); + // mul_mat_id is register-tile only. + const uint32_t tile_k = ggml_is_quantized(context.src0->type) + ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT + : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + // Tiles defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); - defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); @@ -1976,7 +1996,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines); auto decisions = std::make_shared(); - decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; + decisions->tile_k = tile_k; decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; From f5c3ce17d563b7a86561062c1cd82ad7b1ebdd24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Mon, 27 Apr 2026 08:30:55 +0200 Subject: [PATCH 203/249] ggml : use 64 bytes aligned tile buffers (llama/21058) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit | Model | Test | t/s OLD | t/s NEW | Speedup | |:---------------------------------|:-------|----------:|----------:|----------:| | qwen35 0.8B BF16 | pp512 | 584.59 | 595.41 | 1.02 | | qwen35 0.8B BF16 | tg128 | 52.23 | 52.82 | 1.01 | | qwen35 0.8B IQ2_M - 2.7 bpw | pp512 | 260.64 | 261.70 | 1.00 | | qwen35 0.8B IQ2_M - 2.7 bpw | tg128 | 81.17 | 80.89 | 1.00 | | qwen35 0.8B IQ2_XXS - 2.0625 bpw | pp512 | 302.36 | 302.56 | 1.00 | | qwen35 0.8B IQ2_XXS - 2.0625 bpw | tg128 | 84.93 | 85.12 | 1.00 | | qwen35 0.8B IQ3_XXS - 3.0625 bpw | pp512 | 263.22 | 260.01 | 0.99 | | qwen35 0.8B IQ3_XXS - 3.0625 bpw | tg128 | 80.29 | 78.94 | 0.98 | | qwen35 0.8B IQ4_NL - 4.5 bpw | pp512 | 728.65 | 742.09 | 1.02 | | qwen35 0.8B IQ4_NL - 4.5 bpw | tg128 | 82.39 | 84.46 | 1.03 | | qwen35 0.8B IQ4_XS - 4.25 bpw | pp512 | 681.33 | 677.06 | 0.99 | | qwen35 0.8B IQ4_XS - 4.25 bpw | tg128 | 80.18 | 79.28 | 0.99 | | qwen35 0.8B Q2_K_M | pp512 | 413.28 | 415.94 | 1.01 | | qwen35 0.8B Q2_K_M | tg128 | 81.90 | 82.78 | 1.01 | | qwen35 0.8B Q3_K_M | pp512 | 493.17 | 495.08 | 1.00 | | qwen35 0.8B Q3_K_M | tg128 | 82.75 | 83.23 | 1.01 | | qwen35 0.8B Q3_K_S | pp512 | 429.35 | 427.64 | 1.00 | | qwen35 0.8B Q3_K_S | tg128 | 86.69 | 87.02 | 1.00 | | qwen35 0.8B Q4_0 | pp512 | 783.46 | 782.32 | 1.00 | | qwen35 0.8B Q4_0 | tg128 | 88.23 | 87.90 | 1.00 | | qwen35 0.8B Q4_1 | pp512 | 741.71 | 729.76 | 0.98 | | qwen35 0.8B Q4_1 | tg128 | 85.44 | 86.01 | 1.01 | | qwen35 0.8B Q4_K_M | pp512 | 676.24 | 681.31 | 1.01 | | qwen35 0.8B Q4_K_M | tg128 | 76.59 | 77.06 | 1.01 | | qwen35 0.8B Q4_K_S | pp512 | 683.12 | 688.81 | 1.01 | | qwen35 0.8B Q4_K_S | tg128 | 80.50 | 81.19 | 1.01 | | qwen35 0.8B Q5_K_M | pp512 | 635.33 | 642.11 | 1.01 | | qwen35 0.8B Q5_K_M | tg128 | 72.07 | 72.49 | 1.01 | | qwen35 0.8B Q5_K_S | pp512 | 660.95 | 658.18 | 1.00 | | qwen35 0.8B Q5_K_S | tg128 | 72.19 | 72.95 | 1.01 | | qwen35 0.8B Q6_K | pp512 | 647.97 | 638.84 | 0.99 | | qwen35 0.8B Q6_K | tg128 | 72.83 | 72.49 | 1.00 | | qwen35 0.8B Q8_0 | pp512 | 805.01 | 785.49 | 0.98 | | qwen35 0.8B Q8_0 | tg128 | 70.10 | 70.13 | 1.00 | Signed-off-by: Adrien Gallouët --- ggml/src/ggml-cpu/amx/mmq.cpp | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index 93a6d397f79..d9383a04be8 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -2005,12 +2005,12 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v const int lda = KB * sizeof(TA); //const int ldb = KB * sizeof(TB); - static thread_local packed_B_t Tile0[TILE_N * TILE_K]; - static thread_local packed_B_t Tile1[TILE_N * TILE_K]; - static thread_local int8_t Tile23[TILE_M * TILE_K]; + alignas(64) static thread_local packed_B_t Tile0[TILE_N * TILE_K]; + alignas(64) static thread_local packed_B_t Tile1[TILE_N * TILE_K]; + alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K]; - static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; - static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; + alignas(64) static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; + alignas(64) static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; // double buffering C to interleave avx512 and amx int32_t * C_cur = TileC0; @@ -2187,21 +2187,21 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v const int m1 = std::max(M - TILE_M, 0); //const int lda = KB * sizeof(TA); - static thread_local int8_t Tile0[TILE_N * TILE_K]; - static thread_local int8_t Tile1[TILE_N * TILE_K]; - static thread_local int8_t Tile23[TILE_M * TILE_K]; + alignas(64) static thread_local int8_t Tile0[TILE_N * TILE_K]; + alignas(64) static thread_local int8_t Tile1[TILE_N * TILE_K]; + alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K]; // mat mul result for each group - static thread_local int32_t Tile4[TILE_M * TILE_N]; - static thread_local int32_t Tile5[TILE_M * TILE_N]; - static thread_local int32_t Tile6[TILE_M * TILE_N]; - static thread_local int32_t Tile7[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile4[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile5[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile6[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile7[TILE_M * TILE_N]; // sum of each QK_K block, contains 8 groups, int32 - static thread_local int32_t Sumi4[TILE_M * TILE_N]; - static thread_local int32_t Sumi5[TILE_M * TILE_N]; - static thread_local int32_t Sumi6[TILE_M * TILE_N]; - static thread_local int32_t Sumi7[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi4[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi5[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi6[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi7[TILE_M * TILE_N]; const int k_group_size = std::is_same::value ? 16 : 32; for (int i = 0; i < KB; ++i) { From c9ba41397cb81e13f98d896a5a63fd5e9a1ea8dc Mon Sep 17 00:00:00 2001 From: unraido <127105806+unraido@users.noreply.github.com> Date: Mon, 27 Apr 2026 23:25:09 +0900 Subject: [PATCH 204/249] fix: rpc-server cache may not work in Windows environments (llama/22394) * fix: create directory and log cache file name. * Remove GGML_LOG_INFO conditional compilation. --------- Co-authored-by: kotaro --- ggml/src/ggml-rpc/ggml-rpc.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 2ded7397868..505bec73d37 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1101,7 +1101,7 @@ bool rpc_server::set_tensor(const std::vector & input) { fs::path cache_file = fs::path(cache_dir) / hash_str; std::ofstream ofs(cache_file, std::ios::binary); ofs.write((const char *)data, size); - GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str()); + GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.string().c_str()); } ggml_backend_tensor_set(tensor, data, offset, size); return true; From f675a8c9264682c720ae0d3b7badb06227065cc3 Mon Sep 17 00:00:00 2001 From: Rithik Sharma Date: Mon, 27 Apr 2026 08:25:45 -0700 Subject: [PATCH 205/249] add fast mat-vec kernels for i-quants (llama/22344) --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 18 + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 11 + .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 514 ++++++++++++++++++ 3 files changed, 543 insertions(+) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 503171ee14f..08ea2906ada 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1615,6 +1615,24 @@ class ggml_webgpu_shader_lib { defines.push_back("MUL_ACC_" + type_upper); defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } break; } } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index bcec20c1a11..d6d7dbdaf3c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1391,6 +1391,17 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q2_K: use_fast = true; break; + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + use_fast = is_vec; + break; default: break; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 97c9f6d7a09..c2eafee6c75 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -812,6 +812,520 @@ fn main( } #endif +#ifdef MUL_ACC_IQ1_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 50 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base)); + let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu; + let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); + let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ1_M +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 56 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let sc_lo = load_u32_at_src0(block_byte_base + 48u); + let sc_hi = load_u32_at_src0(block_byte_base + 52u); + let sc0 = sc_lo & 0xFFFFu; + let sc1 = (sc_lo >> 16u) & 0xFFFFu; + let sc2 = sc_hi & 0xFFFFu; + let sc3 = (sc_hi >> 16u) & 0xFFFFu; + let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u); + let d = f32(bitcast>(d_bits)[0]); + + let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u), + select(sc0, sc1, sub_blk >= 2u), + sub_blk < 4u); + + let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u); + let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu; + let qh_lo = qh & 0xFFu; + let qh_hi = (qh >> 8u) & 0xFFu; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); + let sub_scale = (sc_u16 >> bit_off) & 0x7u; + let dl = d * f32(2u * sub_scale + 1u); + let qh_byte = select(qh_lo, qh_hi, l >= 2u); + let ll2 = l % 2u; + let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); + let ig = grid_idx * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ2_XXS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 66 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let ls = aux_hi >> 28u; + let db = d * (0.5 + f32(ls)) * 0.25; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; + let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xxs_grid[grid_idx * 2u]; + let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ2_XS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 74 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let scales_byte = get_byte(scales_word, sub_blk % 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let half2 = (l % 2u) * 16u; + let qs_val = (qs_word >> half2) & 0xFFFFu; + let grid_idx = qs_val & 0x1FFu; + let signs_idx = (qs_val >> 9u) & 0x7Fu; + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xs_grid[grid_idx * 2u]; + let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ2_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 82 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); + let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let qh_byte = get_byte(qh_word, sub_blk % 4u); + let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); + let scales_byte = get_byte(sc_word, sub_blk % 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let sign_byte = get_byte(sg_w, l); + let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let gw_lo = iq2s_grid[grid_idx * 2u]; + let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ3_XXS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 98 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u); + let ls = aux >> 28u; + let db = d * (0.5 + f32(ls)) * 0.5; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let signs_idx = (aux >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let grid1 = iq3xxs_grid[grid_idx_0]; + let grid2 = iq3xxs_grid[grid_idx_1]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ3_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let qh_byte = get_byte(qh_word, sub_blk % 4u); + let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u); + let sc_word = load_u32_at_src0(block_byte_base + 106u); + let scales_byte = get_byte(sc_word, sub_blk / 2u); + let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; + let db = d * (1.0 + 2.0 * f32(sub_scale)); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); + let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); + let sign_byte = get_byte(sg_w, l); + let grid1 = iq3s_grid[grid_idx_1]; + let grid2 = iq3s_grid[grid_idx_2]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ4_NL +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + i + 16u]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; + let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ4_XS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 136 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let y_offset = sub_blk * 32u + half * 16u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let scales_h = load_u16_at_src0(block_byte_base + 2u); + let scales_l_word = load_u32_at_src0(block_byte_base + 4u); + let sl_byte = get_byte(scales_l_word, sub_blk / 2u); + let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu; + let sh_bits = (scales_h >> (2u * sub_blk)) & 3u; + let ls = i32(sl | (sh_bits << 4u)); + let dl = d * f32(ls - 32); + + let qs_byte_off = 8u + sub_blk * 16u; + let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off); + let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u); + let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); + let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); + + var row_sum = 0.0; + for (var i = 0u; i < 16u; i++) { + let q_word = select( + select(q_w0, q_w1, i >= 4u), + select(q_w2, q_w3, i >= 12u), + i >= 8u); + let q_byte = get_byte(q_word, i % 4u); + let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); + row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; + } + acc[row] += row_sum; + } + } + } +#endif + #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let subgroup_total = subgroupAdd(acc[row]); From 9c233f11f09c0ea3d7d8df0056c3c312ef9248f3 Mon Sep 17 00:00:00 2001 From: Rithik Sharma Date: Mon, 27 Apr 2026 15:50:59 -0700 Subject: [PATCH 206/249] ggml-webgpu: add Q1_0 support (llama/22374) * add fast matmul matvec q1_0 kernel * ggml-webgpu: drop redundant zero-fills in Q1_0 shmem init --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 9 +++-- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 +++ .../ggml-webgpu/wgsl-shaders/get_rows.wgsl | 18 ++++++++++ .../wgsl-shaders/mul_mat_decls.tmpl | 33 +++++++++++++++++++ .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 32 ++++++++++++++++++ 5 files changed, 94 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 08ea2906ada..fb2c9527f3c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1287,6 +1287,7 @@ class ggml_webgpu_shader_lib { std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); switch (key.src_type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: @@ -1323,7 +1324,9 @@ class ggml_webgpu_shader_lib { defines.push_back("DST_TYPE=f32"); - if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || + if (key.src_type == GGML_TYPE_Q1_0) { + defines.push_back("BLOCK_SIZE=128u"); + } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || key.src_type == GGML_TYPE_IQ4_NL) { defines.push_back("BLOCK_SIZE=32u"); } else if (key.src_type >= GGML_TYPE_Q2_K) { @@ -1657,7 +1660,9 @@ class ggml_webgpu_shader_lib { uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; - if (key.src0_type >= GGML_TYPE_Q2_K) { + if (key.src0_type == GGML_TYPE_Q1_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q2_K) { outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; } else if (key.src0_type >= GGML_TYPE_Q4_0) { outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index d6d7dbdaf3c..6d861c0c781 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1389,6 +1389,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q5_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q2_K: + case GGML_TYPE_Q1_0: use_fast = true; break; case GGML_TYPE_IQ1_S: @@ -3736,6 +3737,7 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm static bool ggml_webgpu_supported_qtype(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3830,6 +3832,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3868,6 +3871,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index 1415798fa6b..5710cd35469 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -27,6 +27,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } #endif +#ifdef Q1_0 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_byte_base = (src_base + offset) * 18; + let d = load_f16_as_f32_at_src(block_byte_base); + for (var j: u32 = 0u; j < 4u; j++) { + let q_packed = load_u32_at_src(block_byte_base + 2u + j * 4u); + let dst_base128 = dst_base + offset * 128u + j * 32u; + for (var k: u32 = 0; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + for (var bit: u32 = 0; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + dst[dst_base128 + k * 8u + bit] = w; + } + } + } +} +#endif + #ifdef Q4_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 5a323818260..15b22c4f731 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -61,6 +61,39 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3 #endif // INIT_SRC1_SHMEM_FLOAT #endif +#ifdef INIT_SRC0_SHMEM_Q1_0 +const BLOCK_SIZE = 128u; +const BLOCK_SIZE_BYTES = 18u; +const NQ = 8u; // 8 weights (1 byte of qs) per thread per iteration + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let tile_m = i / TILE_K; + let tile_k_start = i % TILE_K; + let global_m = offset_m + tile_m; + let global_k_start = k_outer + tile_k_start; + + if (global_m >= params.m) { + break; + } + + let block_k = global_k_start / BLOCK_SIZE; + let byte_in_block = (global_k_start % BLOCK_SIZE) / 8u; + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); + let q_byte = load_u32_at_src0(block_byte_base + 2u + byte_in_block) & 0xFFu; + + for (var bit = 0u; bit < NQ; bit++) { + let global_k = global_k_start + bit; + if (global_k < params.k) { + shmem[i + bit] = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q1_0 + #ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; const BLOCK_SIZE_BYTES = 18u; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index c2eafee6c75..a8000439bfb 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -128,6 +128,38 @@ fn main( } #endif +#ifdef MUL_ACC_Q1_0 +#define BLOCK_SIZE 128 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 16 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; + var row_sum = 0.0; + for (var bit = 0u; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + row_sum += w * x_block[bit]; + } + acc[row] += row_sum; + } + } + } +#endif + #ifdef MUL_ACC_Q4_0 #define BLOCK_SIZE 32 #define BLOCK_SIZE_BYTES 18 From 70e4c0aec058a27f7abf0df1dd7a9660ba3bd4a0 Mon Sep 17 00:00:00 2001 From: hipudding Date: Tue, 28 Apr 2026 14:27:22 +0800 Subject: [PATCH 207/249] CANN: add new ops, optimize existing ops (llama/21204) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New operators: - GGML_OP_SET: implement via aclnnInplaceCopy on target region - GGML_OP_CUMSUM: implement via aclnnCumsum - GGML_OP_FILL: implement via aclnnInplaceFillScalar - GGML_OP_DIAG: implement via aclnnInplaceCopy on diagonal strides - GGML_OP_TRI (lower/lower_diag/upper_diag/upper): implement via aclnnTril(-1/0) and aclnnTriu(0/1) with appropriate diagonal offsets - GGML_OP_SOLVE_TRI: implement via aclnnTriangularSolve - GGML_UNARY_OP_SOFTPLUS: implement via aclnnSoftplus Optimizations: - GLU (SwiGLU/GeGLU/GeGLU_ERF/GeGLU_QUICK): fuse with aclnnSwiGlu / aclnnGeGluV3 when applicable; fallback conditions now checked inside each function rather than at the call site - CROSS_ENTROPY_LOSS: replace 5-kernel sequence (LogSoftmax→Mul→ ReduceSum×2→Muls) with single aclnnSoftmaxCrossEntropyWithLogits call - L2_NORM: fix in-place ClampMin on norm result (was clamping wrong tensor); add eps clamping before division to avoid divide-by-zero - PAD_REFLECT_1D: eliminate per-ne[3] loop; assert contiguity and call ReflectionPad1d once on the full 4-D view; remove redundant nb copies - GET_ROWS: replace IndexSelect with GatherV2 per batch slice; refactor helper into gather_batched lambda with batch loop inlined - SET_ROWS: replace IndexCopy with InplaceIndexCopy per batch slice; refactor helper into scatter_batched lambda with batch loop inlined - OUT_PROD: replace O(ne[3]*ne[2]*ne[1]) Ger+InplaceAdd loop with per-slice Matmul loop (src0 @ src1^T); handles strided-broadcast batch dims where ne02/ne03 may differ from ne2/ne3 - backend memset_tensor: implement via aclrtMemset (was NULL) Bug fixes: - COUNT_EQUAL: use non-inplace EqTensor into a same-type temporary buffer instead of InplaceEqTensor, avoiding corruption of src0 - ACL graph cache (USE_ACL_GRAPH): restore node_type and src_type[] fields in ggml_graph_node_properties; has_matching_properties() was missing type checks, causing F16 and BF16 tensors (same nb[0]=2) to incorrectly share cached graphs and produce wrong results (ERR≈679) - graph cache op_params matching: compare full GGML_MAX_OP_PARAMS bytes so that ops differing only in parameters are not incorrectly replayed from cache --- ggml/src/ggml-cann/aclnn_ops.cpp | 768 +++++++++++++++++++++---------- ggml/src/ggml-cann/aclnn_ops.h | 56 +++ ggml/src/ggml-cann/ggml-cann.cpp | 66 ++- 3 files changed, 628 insertions(+), 262 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index a950475fc3b..2dc0f40917d 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -25,6 +25,7 @@ #include "ggml-impl.h" #include "ggml.h" + #include #include #include @@ -45,7 +46,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -62,6 +65,7 @@ #include #include #include +#include #include #include #include @@ -69,11 +73,15 @@ #include #include #include +#include #include #include #include #include +#include #include +#include +#include #include #include #include @@ -151,6 +159,107 @@ void ggml_cann_op_unary_gated(std::functionsrc[1] != nullptr || swapped != 0) { + ggml_cann_op_unary_gated(silu_fn, ctx, dst); + return; + } + + // aclnnSwiGlu requires the split dim (src->ne[0]) to be even; fall back otherwise. + if (dst->src[0]->ne[0] % 2 != 0) { + ggml_cann_op_unary_gated(silu_fn, ctx, dst); + return; + } + + ggml_tensor * src0 = dst->src[0]; + size_t elem_size = ggml_element_size(src0); + + // src0 GGML: [2*ne0, ne1, ne2, ne3] → 3D view [2*ne0, ne1, ne2*ne3] + // CANN reversed: [ne2*ne3, ne1, 2*ne0], split along CANN dim 2 (last). + int64_t ne0_x2 = src0->ne[0]; + int64_t ne1 = src0->ne[1]; + int64_t ne23 = src0->ne[2] * src0->ne[3]; + int64_t src3d_ne[] = { ne0_x2, ne1, ne23 }; + size_t src3d_nb[] = { (size_t)src0->nb[0], (size_t)src0->nb[1], (size_t)src0->nb[2] }; + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type), + elem_size, src3d_ne, src3d_nb, 3); + + // dst GGML: [ne0, ne1, ne2, ne3] → 3D view [ne0, ne1, ne2*ne3] + int64_t ne0 = dst->ne[0]; + int64_t dst3d_ne[] = { ne0, ne1, ne23 }; + size_t dst3d_nb[] = { (size_t)dst->nb[0], (size_t)dst->nb[1], (size_t)dst->nb[2] }; + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type), + elem_size, dst3d_ne, dst3d_nb, 3); + + // CANN tensor [ne23, ne1, 2*ne0]: split along CANN dim 2 (last) = 2*ne0. + GGML_CANN_CALL_ACLNN_OP(ctx, SwiGlu, acl_src.get(), (int64_t)2, acl_dst.get()); +} + +// Fused GeGLU using aclnnGeGluV3: splits input along ne[0] (CANN last dim), +// activates the LEFT half with GELU, multiplies by right half. +// approximate: 0=tanh, 1=none(erf). activateLeft=true matches GGML convention. +// outGelu is a required-but-discard output buffer. +// +// Falls back to the generic two-kernel path when src[1] != nullptr (two +// independent halves) or swapped != 0 (reversed activation order), as +// aclnnGeGluV3 only handles the single interleaved tensor in standard order. +void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate) { + auto gelu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, Gelu, acl_src, acl_dst); + }; + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + if (dst->src[1] != nullptr || swapped != 0) { + ggml_cann_op_unary_gated(gelu_fn, ctx, dst); + return; + } + + // aclnnGeGluV3 requires the split dim (src->ne[0]) to be even; fall back otherwise. + if (dst->src[0]->ne[0] % 2 != 0) { + ggml_cann_op_unary_gated(gelu_fn, ctx, dst); + return; + } + + ggml_tensor * src0 = dst->src[0]; + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + + // Allocate a temporary buffer for the required outGelu output (same shape as dst). + // Build contiguous strides since the pool allocation is a fresh buffer. + size_t elem_size = ggml_element_size(dst); + int64_t ne[GGML_MAX_DIMS] = { dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3] }; + size_t nb[GGML_MAX_DIMS]; + nb[0] = elem_size; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + size_t gelu_out_size = nb[GGML_MAX_DIMS - 1] * ne[GGML_MAX_DIMS - 1]; + ggml_cann_pool_alloc gelu_out_alloc(ctx.pool(), gelu_out_size); + + acl_tensor_ptr acl_gelu_out = ggml_cann_create_tensor( + gelu_out_alloc.get(), ggml_cann_type_mapping(dst->type), elem_size, ne, nb, GGML_MAX_DIMS); + // V3 adds activateLeft param; true → Gelu(left)*right, matching GGML convention. + // GGML dim 0 → CANN last dim (index GGML_MAX_DIMS-1 = 3 for 4D tensor). + GGML_CANN_CALL_ACLNN_OP(ctx, GeGluV3, acl_src.get(), (int64_t)(GGML_MAX_DIMS - 1), approximate, true, + acl_dst.get(), acl_gelu_out.get()); +} + /** * @brief Repeats elements of a tensor along each dimension according to the * specified repeat array. @@ -445,28 +554,33 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes); void * buffer = temp_buffer_allocator.get(); - int64_t div_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] }; - size_t div_nb[GGML_MAX_DIMS]; - div_nb[0] = sizeof(float); + int64_t norm_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] }; + size_t norm_nb[GGML_MAX_DIMS]; + norm_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; ++i) { - div_nb[i] = div_nb[i - 1] * div_ne[i - 1]; + norm_nb[i] = norm_nb[i - 1] * norm_ne[i - 1]; } - acl_tensor_ptr acl_div = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, div_ne, div_nb, GGML_MAX_DIMS); + acl_tensor_ptr acl_norm = ggml_cann_create_tensor(buffer, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS); std::vector norm_dims = { 3 }; acl_int_array_ptr dims_array = ggml_cann_create_int_array(norm_dims.data(), norm_dims.size()); float p_value = 2.0f; acl_scalar_ptr p_scalar = ggml_cann_create_scalar(&p_value, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_div.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_norm.get()); + + ggml_cann_pool_alloc clamp_buffer_allocator(ctx.pool()); + acl_tensor_ptr acl_clamped; - // Clamp norm to at least eps: scale = 1/fmaxf(norm, eps) - acl_scalar_ptr acl_min = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT); - float flt_max = FLT_MAX; - acl_scalar_ptr acl_max = ggml_cann_create_scalar(&flt_max, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_div.get(), acl_min.get(), acl_max.get(), acl_div.get()); + if (eps > 0.0f) { + void * clamp_buf = clamp_buffer_allocator.alloc(n_bytes); + acl_clamped = ggml_cann_create_tensor(clamp_buf, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS); + acl_scalar_ptr eps_scalar = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, ClampMin, acl_norm.get(), eps_scalar.get(), acl_clamped.get()); + } - GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div.get(), acl_dst.get()); + aclTensor * acl_div_input = acl_clamped ? acl_clamped.get() : acl_norm.get(); + GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div_input, acl_dst.get()); } void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) { @@ -482,56 +596,30 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * logits_nb[1] = logits_nb[0] * logits_ne[0]; acl_tensor_ptr acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2); - size_t log_softmax_type_size = sizeof(float); - int64_t log_softmax_n_bytes = nr * nc * log_softmax_type_size; - ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes); - void * log_softmax_buffer = log_softmax_allocator.get(); - - int64_t log_softmax_ne[] = { nc, nr }; - size_t log_softmax_nb[2]; - log_softmax_nb[0] = log_softmax_type_size; - log_softmax_nb[1] = log_softmax_nb[0] * log_softmax_ne[0]; - acl_tensor_ptr acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size, - log_softmax_ne, log_softmax_nb, 2); - - GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits.get(), 1, acl_log_softmax.get()); - int64_t labels_ne[] = { nc, nr }; size_t labels_nb[2]; labels_nb[0] = ggml_type_size(src1->type); labels_nb[1] = labels_nb[0] * labels_ne[0]; acl_tensor_ptr acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2); - size_t mul_type_size = sizeof(float); - int64_t mul_n_bytes = nr * nc * mul_type_size; - ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes); - void * mul_buffer = mul_allocator.get(); - - int64_t mul_ne[] = { nc, nr }; - size_t mul_nb[2]; - mul_nb[0] = mul_type_size; - mul_nb[1] = mul_nb[0] * mul_ne[0]; - acl_tensor_ptr acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2); - - GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax.get(), acl_labels.get(), acl_mul_result.get()); + size_t loss_per_sample_type_size = sizeof(float); + int64_t loss_per_sample_n_bytes = nr * loss_per_sample_type_size; + ggml_cann_pool_alloc loss_per_sample_allocator(ctx.pool(), loss_per_sample_n_bytes); + void * loss_per_sample_buffer = loss_per_sample_allocator.get(); - size_t sum_per_sample_type_size = sizeof(float); - int64_t sum_per_sample_n_bytes = nr * sum_per_sample_type_size; - ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes); - void * sum_per_sample_buffer = sum_per_sample_allocator.get(); + int64_t loss_per_sample_ne[] = { nr }; + size_t loss_per_sample_nb[1]; + loss_per_sample_nb[0] = loss_per_sample_type_size; + acl_tensor_ptr acl_loss_per_sample = ggml_cann_create_tensor( + loss_per_sample_buffer, ACL_FLOAT, loss_per_sample_type_size, loss_per_sample_ne, loss_per_sample_nb, 1); - int64_t sum_per_sample_ne[] = { nr }; - size_t sum_per_sample_nb[1]; - sum_per_sample_nb[0] = sum_per_sample_type_size; - acl_tensor_ptr acl_sum_per_sample = ggml_cann_create_tensor( - sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1); + size_t backprop_n_bytes = nr * nc * sizeof(float); + ggml_cann_pool_alloc backprop_allocator(ctx.pool(), backprop_n_bytes); + void * backprop_buffer = backprop_allocator.get(); + acl_tensor_ptr acl_backprop = ggml_cann_create_tensor(backprop_buffer, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2); - std::vector sum_dims = { 1 }; - acl_int_array_ptr dims_array = ggml_cann_create_int_array(sum_dims.data(), sum_dims.size()); - bool keep_dims = false; - - GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result.get(), dims_array.get(), keep_dims, ACL_FLOAT, - acl_sum_per_sample.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, SoftmaxCrossEntropyWithLogits, acl_logits.get(), acl_labels.get(), + acl_loss_per_sample.get(), acl_backprop.get()); size_t total_sum_type_size = sizeof(float); int64_t total_sum_n_bytes = 1 * total_sum_type_size; @@ -547,11 +635,12 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * std::vector total_sum_dims = { 0 }; acl_int_array_ptr total_sum_dims_array = ggml_cann_create_int_array(total_sum_dims.data(), total_sum_dims.size()); + bool keep_dims = false; - GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT, + GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_loss_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT, acl_total_sum.get()); - float value = -1.0f / static_cast(nr); + float value = 1.0f / static_cast(nr); acl_scalar_ptr scale_factor = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT); acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1); @@ -589,6 +678,33 @@ void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { acl_mean_out.get(), acl_rstd_out.get()); } +void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + size_t param_nb[] = { ggml_element_size(src0), nb1, nb2, nb3 }; + + // Create a view of dst at the target offset with src1's dimensions + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset); + acl_tensor_ptr acl_src1 = ggml_cann_create_tensor(src1); + + if (!inplace) { + // First copy src0 to dst entirely + size_t cpy_size = ggml_nbytes(dst); + ACL_CHECK( + aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + } + + // Copy src1 into the target region of dst + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst.get(), acl_src1.get()); +} + void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; @@ -652,6 +768,113 @@ void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst) { aclnn_reduce_sum(ctx, dst, reduce_dims, 4); } +void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + // GGML cumsum operates along dim 0 (innermost / ne[0]). + // ggml_cann_create_tensor reverses dimensions to [ne3,ne2,ne1,ne0], + // so GGML dim 0 maps to CANN dim 3 (the last dim of the 4-D tensor). + GGML_CANN_CALL_ACLNN_OP(ctx, Cumsum, acl_src.get(), (int64_t)3, + ggml_cann_type_mapping(dst->type), acl_dst.get()); +} + +void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // A: [N, N, B2, B3] lower triangular + ggml_tensor * src1 = dst->src[1]; // B: [K, N, B2, B3] + + acl_tensor_ptr acl_a = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_b = ggml_cann_create_tensor(src1); + acl_tensor_ptr acl_x = ggml_cann_create_tensor(dst); + + // mOut: triangular copy of A (required output), same shape as A. + const size_t a_bytes = ggml_nbytes(src0); + ggml_cann_pool_alloc m_alloc(ctx.pool(), a_bytes); + acl_tensor_ptr acl_m = ggml_cann_create_tensor( + m_alloc.get(), ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS); + + // Solve AX = B: upper=false (lower tri), transpose=false, unitriangular=false. + GGML_CANN_CALL_ACLNN_OP(ctx, TriangularSolve, + acl_b.get(), acl_a.get(), false, false, false, + acl_x.get(), acl_m.get()); +} + +void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + + GGML_ASSERT(src->ne[1] == 1); + + const int64_t N = src->ne[0]; + const int64_t n_batch = src->ne[2] * src->ne[3]; + const size_t nb_f32 = sizeof(float); + + // Fill dst with zeros. + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + { + float zero = 0.0f; + acl_scalar_ptr acl_zero = ggml_cann_create_scalar(&zero, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_zero.get()); + } + + // Copy src vector onto the diagonal of dst via strided views. + // src viewed as [N, n_batch], contiguous strides. + int64_t ne_vec[2] = { N, n_batch }; + size_t nb_src_vec[2] = { nb_f32, N * nb_f32 }; + // dst diagonal view: stride (N+1)*4 steps along the diagonal. + size_t nb_dst_diag[2] = { (N + 1) * nb_f32, N * N * nb_f32 }; + + acl_tensor_ptr acl_src_vec = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne_vec, nb_src_vec, 2); + acl_tensor_ptr acl_dst_diag = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_dst_diag, 2); + + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst_diag.get(), acl_src_vec.get()); +} + +void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + float c = ggml_get_op_params_f32(dst, 0); + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + acl_scalar_ptr acl_c = ggml_cann_create_scalar(&c, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_c.get()); +} + +void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + + const int64_t S = src->ne[0]; + const int64_t n_batch = src->ne[2] * src->ne[3]; + const size_t nb_f32 = sizeof(float); + + int64_t ne3d[3] = { S, S, n_batch }; + size_t nb3d[3] = { nb_f32, S * nb_f32, S * S * nb_f32 }; + + const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0); + + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3); + + switch (ttype) { + case GGML_TRI_TYPE_LOWER: + // Tril(-1): preserve row > col (strict lower), zero upper + diagonal. + GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)-1, acl_dst.get()); + break; + case GGML_TRI_TYPE_UPPER_DIAG: + // Triu(0): preserve row <= col (upper + diagonal), zero strict lower. + GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)0, acl_dst.get()); + break; + case GGML_TRI_TYPE_UPPER: + // Triu(1): preserve row < col (strict upper), zero lower + diagonal. + GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)1, acl_dst.get()); + break; + case GGML_TRI_TYPE_LOWER_DIAG: + // Tril(0): preserve row >= col (lower + diagonal), zero strict upper. + GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)0, acl_dst.get()); + break; + default: + GGML_ABORT("unsupported tri type"); + } +} + void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src = dst->src[0]; acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW); @@ -1695,152 +1918,90 @@ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { aclnn_softmax(ctx, softmax_tensor.get(), 3, acl_dst.get()); } -/** - * @brief Performs index select operation on a 4D tensor using the CANN backend. - * - * This function applies the `IndexSelect` operation along a specific dimension - * of the source tensor (`src_buffer`) using the indices from the index tensor (`index`). - * It iterates over the last two dimensions of the source tensor, creates the corresponding - * CANN tensors for the source, index, and output slices, and executes the `IndexSelect` - * operation for each slice. - * - * @param ctx The context for CANN backend operations. - * @param src_buffer The source buffer containing the 4D input tensor data. - * @param src_ne The dimensions of the source tensor. - * @param src_nb The strides (byte offsets) of the source tensor. - * @param dst_buffer The destination buffer where the output tensor data will be written. - * @param dst_ne The dimensions of the destination tensor. - * @param dst_nb The strides (byte offsets) of the destination tensor. - * @param index The index tensor specifying the indices to select from the source tensor. - * @param type The data type of the source and destination tensors. - */ -static void aclnn_index_select_4d(ggml_backend_cann_context & ctx, - void * src_buffer, - int64_t * src_ne, - size_t * src_nb, - void * dst_buffer, - int64_t * dst_ne, - size_t * dst_nb, - ggml_tensor * index, - ggml_type type) { - for (int64_t i = 0; i < src_ne[3]; i++) { - for (int64_t j = 0; j < src_ne[2]; j++) { - // src - acl_tensor_ptr acl_src_tensor = - ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2); - - // index - acl_tensor_ptr acl_index = ggml_cann_create_tensor( - (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1], - ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1); - - // out - acl_tensor_ptr acl_out = - ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2); - GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor.get(), 0, acl_index.get(), acl_out.get()); - } - } -} - -/** - * @brief Performs inplace index copy operation on a 4D tensor using the CANN backend. - * - * This function applies the `IndexCopy` operation along a specific dimension of the - * destination tensor (`dst_buffer`) by copying elements from the source tensor (`src_buffer`) - * to positions specified by the index tensor (`index`). - * It iterates over the last two dimensions of the tensors, creates the corresponding - * CANN tensors for source, index, and destination slices, and performs the index copy - * operation for each slice. - * - * @param ctx The context for CANN backend operations. - * @param src_buffer The source buffer containing the 4D input tensor data to be copied. - * @param src_ne The dimensions of the source tensor. - * @param src_nb The strides (byte offsets) of the source tensor. - * @param dst_buffer The destination buffer where values will be copied to. - * @param dst_ne The dimensions of the destination tensor. - * @param dst_nb The strides (byte offsets) of the destination tensor. - * @param index The index tensor specifying target positions in the destination tensor. - * @param type The data type of the source and destination tensors. - */ -static void aclnn_index_copy_4d(ggml_backend_cann_context & ctx, - void * src_buffer, - int64_t * src_ne, - size_t * src_nb, - void * dst_buffer, - int64_t * dst_ne, - size_t * dst_nb, - ggml_tensor * index, - ggml_type type) { - for (int64_t i = 0; i < src_ne[3]; i++) { - for (int64_t j = 0; j < src_ne[2]; j++) { - // src - acl_tensor_ptr acl_src_tensor = - ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2); - - // index - acl_tensor_ptr acl_index = ggml_cann_create_tensor( - (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1], - ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1); - - // out - acl_tensor_ptr acl_out = - ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out.get(), 0, acl_index.get(), acl_src_tensor.get()); - } - } -} void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; // src + ggml_tensor * src0 = dst->src[0]; // weight ggml_tensor * src1 = dst->src[1]; // index GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16); + // n_idx: number of row indices per (i2, i3) batch slice. + // ggml guarantees: src0->ne[2] == src1->ne[1], src0->ne[3] == src1->ne[2], src1->ne[3] == 1. + const int64_t n_idx = src1->ne[0]; + + // Gather all (i2, i3) batch slices from src into dst. + // ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0]. + // GatherV2 with dim=0 gathers along ACL dim-0 == ggml ne[1] (the vocabulary / row axis). + // nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape, + // nb[2..3] for computing per-batch-slice base pointer offsets). + auto gather_batched = [&](void * src_base, aclDataType acl_type, size_t type_size, + const size_t * nb) { + int64_t src_ne[2] = { src0->ne[0], src0->ne[1] }; + size_t src_nb_2d[2] = { nb[0], nb[1] }; + int64_t dst_ne[2] = { src0->ne[0], n_idx }; + size_t dst_nb_2d[2] = { dst->nb[0], dst->nb[1] }; + int64_t idx_ne[1] = { n_idx }; + size_t idx_nb[1] = { (size_t)ggml_element_size(src1) }; + + for (int64_t i3 = 0; i3 < src0->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < src0->ne[2]; i2++) { + acl_tensor_ptr acl_src = ggml_cann_create_tensor( + (char *)src_base + i3 * nb[3] + i2 * nb[2], + acl_type, type_size, src_ne, src_nb_2d, 2); + acl_tensor_ptr acl_idx = ggml_cann_create_tensor( + (char *)src1->data + i3 * src1->nb[2] + i2 * src1->nb[1], + ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1), + idx_ne, idx_nb, 1); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor( + (char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2], + acl_type, type_size, dst_ne, dst_nb_2d, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, GatherV2, acl_src.get(), 0, acl_idx.get(), acl_dst.get()); + } + } + }; + switch (src0->type) { case GGML_TYPE_BF16: case GGML_TYPE_F16: case GGML_TYPE_F32: if (src0->type == dst->type) { - aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, - dst->type); + gather_batched(src0->data, + ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type), + src0->nb); } else { - acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); - ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst)); - void * src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = dst->nb[0]; + // Cast src0 to dst type, then gather. + ggml_cann_pool_alloc src_cast_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_element_size(dst)); + size_t src_cast_nb[GGML_MAX_DIMS]; + src_cast_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1]; } - acl_tensor_ptr src_trans_tensor = - ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type)); - aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, - dst->type); + acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor( + src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->ne, src_cast_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type)); + + gather_batched(src_cast_allocator.get(), + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src_cast_nb); } break; case GGML_TYPE_Q8_0: { - // add 1 dim for bcast mul. + // Dequantize Q8_0 to dst type, then gather. size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], dequant_nb[GGML_MAX_DIMS + 1]; int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne; - int64_t scale_offset = 0; - // [3,4,5,64] -> [3,4,5,2,32] - weight_ne[0] = QK8_0; - weight_ne[1] = src0->ne[0] / QK8_0; - weight_nb[0] = sizeof(int8_t); - weight_nb[1] = weight_nb[0] * weight_ne[0]; + weight_ne[0] = QK8_0; + weight_ne[1] = src0->ne[0] / QK8_0; + weight_nb[0] = sizeof(int8_t); + weight_nb[1] = weight_nb[0] * weight_ne[0]; for (int i = 2; i < GGML_MAX_DIMS + 1; i++) { weight_ne[i] = src0->ne[i - 1]; weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1]; } - // [3,4,5,64] -> [3,4,5,2,1] scale_ne[0] = 1; scale_ne[1] = src0->ne[0] / QK8_0; scale_nb[0] = sizeof(uint16_t); @@ -1849,31 +2010,33 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { scale_ne[i] = src0->ne[i - 1]; scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1]; } - // [3,4,5,64] -> [3,4,5,2,32] dequant_ne = weight_ne; dequant_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS + 1; i++) { dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1]; } - scale_offset = ggml_nelements(src0) * sizeof(int8_t); - ggml_cann_pool_alloc dequant_buffer_allocator(ctx.pool(), - ggml_nelements(src0) * ggml_type_size(dst->type)); - acl_tensor_ptr acl_weight_tensor = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), - weight_ne, weight_nb, GGML_MAX_DIMS + 1); - acl_tensor_ptr acl_scale_tensor = - ggml_cann_create_tensor(src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, - GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); - acl_tensor_ptr dequant_tensor = - ggml_cann_create_tensor(dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); - aclnn_mul(ctx, acl_weight_tensor.get(), acl_scale_tensor.get(), dequant_tensor.get()); - dequant_nb[0] = ggml_type_size(dst->type); + const int64_t scale_offset = ggml_nelements(src0) * sizeof(int8_t); + ggml_cann_pool_alloc dequant_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_type_size(dst->type)); + acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), + weight_ne, weight_nb, GGML_MAX_DIMS + 1); + acl_tensor_ptr acl_scale = ggml_cann_create_tensor( + src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, + GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); + acl_tensor_ptr acl_dequant = ggml_cann_create_tensor( + dequant_allocator.get(), ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); + aclnn_mul(ctx, acl_weight.get(), acl_scale.get(), acl_dequant.get()); + + // Reinterpret dequant buffer as 4D [src0->ne] with contiguous strides. dequant_ne = src0->ne; + dequant_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1]; } - aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne, - dst->nb, src1, dst->type); + gather_batched(dequant_allocator.get(), + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + dequant_nb); break; } default: @@ -1883,31 +2046,70 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { } void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; // src - ggml_tensor * src1 = dst->src[1]; // index + ggml_tensor * src0 = dst->src[0]; // source values + ggml_tensor * src1 = dst->src[1]; // row indices + + // n_idx: number of source rows to scatter per batch slice. + // ggml guarantees: src0->ne[1] == src1->ne[0]. + const int64_t n_idx = src1->ne[0]; + + // Copy n_idx rows of src [ne0, n_idx] into dst [ne0, ne1] at positions given by a 1D index. + // ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0] for dst. + // InplaceIndexCopy with dim=0 copies along ACL dim-0 == ggml ne[1] (the row axis). + // src_nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape, + // nb[2..3] for computing per-batch-slice base pointer offsets). + auto scatter_batched = [&](void * src_base, aclDataType acl_type, size_t type_size, + const size_t * src_nb) { + int64_t d_ne[2] = { dst->ne[0], dst->ne[1] }; + size_t d_nb[2] = { dst->nb[0], dst->nb[1] }; + int64_t s_ne[2] = { dst->ne[0], n_idx }; + size_t s_nb_2d[2] = { src_nb[0], src_nb[1] }; + int64_t i_ne[1] = { n_idx }; + size_t i_nb[1] = { (size_t)ggml_element_size(src1) }; + + for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < dst->ne[2]; i2++) { + acl_tensor_ptr acl_dst = ggml_cann_create_tensor( + (char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2], + acl_type, type_size, d_ne, d_nb, 2); + acl_tensor_ptr acl_idx = ggml_cann_create_tensor( + (char *)src1->data + (i3 % src1->ne[2]) * src1->nb[2] + (i2 % src1->ne[1]) * src1->nb[1], + ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1), + i_ne, i_nb, 1); + acl_tensor_ptr acl_src = ggml_cann_create_tensor( + (char *)src_base + i3 * src_nb[3] + i2 * src_nb[2], + acl_type, type_size, s_ne, s_nb_2d, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_dst.get(), 0, acl_idx.get(), acl_src.get()); + } + } + }; switch (dst->type) { case GGML_TYPE_F32: - { - aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, dst->type); - break; - } + scatter_batched(src0->data, + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->nb); + break; case GGML_TYPE_F16: case GGML_TYPE_BF16: { - acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); - ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t)); - void * src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = sizeof(uint16_t); + // Cast src0 (F32) to dst type first. + ggml_cann_pool_alloc src_cast_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_type_size(dst->type)); + size_t src_cast_nb[GGML_MAX_DIMS]; + src_cast_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1]; } - acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type)); - aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, - dst->type); + acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor( + src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->ne, src_cast_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type)); + + scatter_batched(src_cast_allocator.get(), + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src_cast_nb); break; } default: @@ -3268,29 +3470,50 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst int64_t paddingsArray[2] = { opts[0], opts[1] }; acl_int_array_ptr paddings = ggml_cann_create_int_array(paddingsArray, 2); - for (int64_t i = 0; i < src0->ne[3]; i++) { - acl_tensor_ptr acl_src = - ggml_cann_create_tensor((char *) src0->data + i * src0->ne[3], ggml_cann_type_mapping(src0->type), - ggml_element_size(src0), src0->ne, src0->nb, 3); + // Collapsing ne[2]*ne[3] into a single batch dimension requires that dim3 + // is contiguous with respect to dim2 in both src and dst. + GGML_ASSERT(src0->nb[3] == src0->nb[2] * src0->ne[2]); + GGML_ASSERT(dst->nb[3] == dst->nb[2] * dst->ne[2]); - acl_tensor_ptr acl_dst = - ggml_cann_create_tensor((char *) dst->data + i * src0->ne[3], ggml_cann_type_mapping(dst->type), - ggml_element_size(dst), dst->ne, dst->nb, 3); + int64_t src_ne_3d[3] = { src0->ne[0], src0->ne[1], src0->ne[2] * src0->ne[3] }; + int64_t dst_ne_3d[3] = { dst->ne[0], dst->ne[1], dst->ne[2] * dst->ne[3] }; - GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get()); - } + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type), + ggml_element_size(src0), src_ne_3d, src0->nb, 3); + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type), + ggml_element_size(dst), dst_ne_3d, dst->nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get()); } void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; + // Write element-wise equality (0 or 1) into a temporary buffer to avoid + // modifying src0 in-place. Use the same type as src0 so ReduceSum can + // consume it directly without a type cast. + ggml_cann_pool_alloc eq_alloc(ctx.pool(), ggml_nelements(src0) * ggml_element_size(src0)); + size_t eq_nb[GGML_MAX_DIMS]; + eq_nb[0] = ggml_element_size(src0); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + eq_nb[i] = eq_nb[i - 1] * src0->ne[i - 1]; + } + acl_tensor_ptr acl_eq = ggml_cann_create_tensor( + eq_alloc.get(), ggml_cann_type_mapping(src0->type), ggml_element_size(src0), + src0->ne, eq_nb, GGML_MAX_DIMS); + acl_tensor_ptr acl_self = ggml_cann_create_tensor(src0); acl_tensor_ptr acl_other = ggml_cann_create_tensor(src1); - - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self.get(), acl_other.get()); - - ggml_cann_sum(ctx, dst); + GGML_CANN_CALL_ACLNN_OP(ctx, EqTensor, acl_self.get(), acl_other.get(), acl_eq.get()); + + // Sum the 0/1 values into dst. + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + int64_t dims[4] = { 0, 1, 2, 3 }; + acl_int_array_ptr dims_arr = ggml_cann_create_int_array(dims, 4); + GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_eq.get(), dims_arr.get(), true, + ggml_cann_type_mapping(dst->type), acl_dst.get()); } void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) { @@ -3306,6 +3529,27 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src.get(), alpha.get(), acl_dst.get()); } +void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + + float beta_val = 1.0f; + float threshold_val = 20.0f; + acl_scalar_ptr beta = ggml_cann_create_scalar(&beta_val, ACL_FLOAT); + acl_scalar_ptr threshold = ggml_cann_create_scalar(&threshold_val, ACL_FLOAT); + + GGML_CANN_CALL_ACLNN_OP(ctx, Softplus, acl_src.get(), beta.get(), threshold.get(), acl_dst.get()); +} + +void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + auto gelu_quick_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); + }; + ggml_cann_op_unary_gated(gelu_quick_fn, ctx, dst); +} + /** * @brief Performs expert-specific matrix multiplication (MoE) with * floating-point precision using the CANN backend. @@ -3892,46 +4136,65 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst } static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; // weight - ggml_tensor * src1 = dst->src[1]; // input + ggml_tensor * src0 = dst->src[0]; // weight [ne00=m, ne01=K, ne02, ne03] + ggml_tensor * src1 = dst->src[1]; // input [ne10=n, ne11=K, ne12, ne13] GGML_TENSOR_BINARY_OP_LOCALS - acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); + // dst[i,j] = sum_k src0[i,k] * src1[j,k] i.e. dst = src0 @ src1^T. + // + // ggml_cann_create_tensor reverses dimension order, so ACL sees: + // acl_src0 slice: ggml[m,K] -> ACL[K,m] + // acl_src1 slice: ggml[n,K] -> ACL[K,n] + // acl_dst slice: ggml[m,n] -> ACL[n,m] + // + // Build a transposed view of src1 by swapping ne[0]/ne[1]: + // src1_t: ggml[K,n] (swapped strides) -> ACL[n,K] + // + // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst ✓ + // + // The outer batch loop is kept because src0 may have fewer batch slices than + // dst (ne02 <= ne2, ne03 <= ne3): this is a strided-broadcast not supported + // by standard CANN Matmul broadcasting. + + const aclDataType src0_acl_type = ggml_cann_type_mapping(src0->type); + const aclDataType src1_acl_type = ggml_cann_type_mapping(src1->type); + const aclDataType dst_acl_type = ggml_cann_type_mapping(dst->type); + const size_t src0_type_sz = ggml_type_size(src0->type); + const size_t src1_type_sz = ggml_type_size(src1->type); + const size_t dst_type_sz = ggml_type_size(dst->type); const int64_t dps2 = ne2 / ne02; const int64_t dps3 = ne3 / ne03; + for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = 0; i2 < ne2; i2++) { const int64_t i02 = i2 / dps2; const int64_t i03 = i3 / dps3; - const int64_t i12 = i2; - const int64_t i13 = i3; - acl_tensor_ptr accumulator = - ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dst->ne, dst->nb, 2); - - // The outer product needs to be accumulated in this dimension. - for (int64_t i1 = 0; i1 < ne11; i1++) { - acl_tensor_ptr acl_input = ggml_cann_create_tensor( - (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src1->ne, src1->nb, 1); - - acl_tensor_ptr acl_weight = ggml_cann_create_tensor( - (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src0->ne, src0->nb, 1); - - ggml_cann_pool_alloc output_allocator(ctx.pool()); - void * output_buffer = output_allocator.alloc(ggml_nbytes(dst)); - acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dst->ne, dst->nb, 2); - - GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get()); - float alpha_value = 1.0f; - aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha); - } + // src0 2D slice at [i02, i03]: ggml [m, K] -> ACL [K, m] + int64_t src0_ne[2] = { ne00, ne01 }; + size_t src0_nb[2] = { nb00, nb01 }; + acl_tensor_ptr acl_src0_s = ggml_cann_create_tensor( + (char *) src0->data + i02 * nb02 + i03 * nb03, + src0_acl_type, src0_type_sz, src0_ne, src0_nb, 2); + + // src1 transposed 2D slice at [i2, i3]: swap ne/nb -> ggml[K,n] -> ACL[n,K] + int64_t src1_t_ne[2] = { ne11, ne10 }; + size_t src1_t_nb[2] = { nb11, nb10 }; + acl_tensor_ptr acl_src1_t = ggml_cann_create_tensor( + (char *) src1->data + i2 * nb12 + i3 * nb13, + src1_acl_type, src1_type_sz, src1_t_ne, src1_t_nb, 2); + + // dst 2D slice at [i2, i3]: ggml [m, n] -> ACL [n, m] + int64_t dst_ne[2] = { ne0, ne1 }; + size_t dst_nb[2] = { nb0, nb1 }; + acl_tensor_ptr acl_dst_s = ggml_cann_create_tensor( + (char *) dst->data + i2 * nb2 + i3 * nb3, + dst_acl_type, dst_type_sz, dst_ne, dst_nb, 2); + + // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst_s ✓ + GGML_CANN_CALL_ACLNN_OP(ctx, Matmul, + acl_src1_t.get(), acl_src0_s.get(), acl_dst_s.get(), (int8_t) 1); } } } @@ -4170,3 +4433,4 @@ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * } } } + diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 7f5ba4d3302..cdbf9260f85 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -32,6 +32,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -47,6 +50,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -69,6 +75,9 @@ */ void ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate); + /** * @brief Applies the Leaky ReLU activation function to a tensor using the CANN * backend. @@ -325,6 +334,48 @@ void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst); void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst); +/** + * @brief Computes the cumulative sum of a ggml tensor along dim 0 using the + * CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_CUMSUM`. + */ +void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Computes a triangular mask (tril/triu) of a square ggml tensor + * using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_TRI`. + */ +void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Solves a triangular linear system AX=B using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_SOLVE_TRI`. + */ +void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Creates a diagonal matrix from a vector using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_DIAG`. + */ +void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Fills a tensor with a constant scalar value using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_FILL`. + */ +void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Upsamples a ggml tensor using nearest neighbor interpolation using * the CANN backend. @@ -461,6 +512,9 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor * // @see ggml_cann_dup. void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst); +// @see ggml_cann_acc, but copies src1 into dst instead of adding. +void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Computes the softmax activation with optional masking. * @@ -813,6 +867,8 @@ void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst); * dst->op is expected to be `GGML_OP_STEP`. */ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Performs the Flash Attention extended operator using the CANN backend. diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 5fc484b342b..3618ba7f6f6 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1428,6 +1428,22 @@ static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer, return false; } +/** + * @brief Set a region of a tensor's device memory to a specified value. + * + * @param buffer The CANN buffer containing the tensor. + * @param tensor Pointer to the tensor whose memory will be set. + * @param value The value to which each byte in the region will be set. + * @param offset Byte offset within the tensor's data to start setting. + * @param size Number of bytes to set. + */ +static void ggml_backend_cann_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; + + ggml_cann_set_device(ctx->device); + ACL_CHECK(aclrtMemset((char *) tensor->data + offset, size, value, size)); +} + /** * @brief Clear a CANN buffer by setting all its memory to a specified value. * @@ -1454,7 +1470,7 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer, /* .get_base = */ ggml_backend_cann_buffer_get_base, /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor, - /* .memset_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_cann_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor, /* .set_tensor_2d = */ NULL, @@ -1835,6 +1851,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_UNARY_OP_STEP: ggml_cann_step(ctx, dst); break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_cann_softplus(ctx, dst); + break; default: return false; } @@ -1845,20 +1864,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg GGML_CANN_CALL_OP_UNARY_GATED(Relu); break; case GGML_GLU_OP_GEGLU: + ggml_cann_geglu(ctx, dst, 0); // approximate=0 → tanh + break; case GGML_GLU_OP_GEGLU_ERF: - // aclnnGelu internally uses the erf-based approximation. - GGML_CANN_CALL_OP_UNARY_GATED(Gelu); + ggml_cann_geglu(ctx, dst, 1); // approximate=1 → erf break; case GGML_GLU_OP_SWIGLU: - GGML_CANN_CALL_OP_UNARY_GATED(Silu); + ggml_cann_swiglu(ctx, dst); break; case GGML_GLU_OP_GEGLU_QUICK: - { - auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { - GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); - }; - ggml_cann_op_unary_gated(lambda, ctx, dst); - } + ggml_cann_geglu_quick(ctx, dst); break; default: return false; @@ -1920,6 +1935,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_CPY: ggml_cann_cpy(ctx, dst); break; + case GGML_OP_SET: + ggml_cann_set(ctx, dst); + break; case GGML_OP_CONT: ggml_cann_dup(ctx, dst); break; @@ -1989,6 +2007,21 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_SSM_CONV: ggml_cann_ssm_conv(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cann_cumsum(ctx, dst); + break; + case GGML_OP_TRI: + ggml_cann_tri(ctx, dst); + break; + case GGML_OP_FILL: + ggml_cann_fill(ctx, dst); + break; + case GGML_OP_DIAG: + ggml_cann_diag(ctx, dst); + break; + case GGML_OP_SOLVE_TRI: + ggml_cann_solve_tri(ctx, dst); + break; default: return false; } @@ -2324,6 +2357,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, if (use_cann_graph) { // If no matching graph is found, the graph needs to be recaptured. graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph); + if (graph_capture_required) { // If no matching graph is found, add a new ACL graph. ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph); @@ -2382,6 +2416,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_UNARY_OP_SGN: case GGML_UNARY_OP_STEP: case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_SOFTPLUS: return true; default: return false; @@ -2572,6 +2607,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: + case GGML_OP_SET: case GGML_OP_GROUP_NORM: return true; case GGML_OP_PAD: @@ -2649,6 +2685,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten } case GGML_OP_SSM_CONV: return true; + case GGML_OP_CUMSUM: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_TRI: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_FILL: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_DIAG: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SOLVE_TRI: + return op->src[0]->type == GGML_TYPE_F32; default: return false; } From ca624d86abdbb9f332850227fc02d4b2f6d4f10e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Tue, 28 Apr 2026 08:56:02 +0200 Subject: [PATCH 208/249] ggml : revert to -lm linking instead of find_library (llama/22355) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml : revert to -lm linking instead of find_library `find_library(MATH_LIBRARY m)` was introduced recently, but it breaks CUDA compilation with GGML_STATIC. I could not find any valid use case where we would prefer `find_library` over the standard `-lm` approach. This commit is also meant to start a discussion if there is a valid reason to keep `find_library(MATH_LIBRARY m)`, we should clarify what problem it was solving and find an alternative fix that does not break CUDA with GGML_STATIC. Signed-off-by: Adrien Gallouët * ggml : use MATH_LIBRARY only if defined Signed-off-by: Adrien Gallouët * ggml : fix initial broken condition Signed-off-by: Adrien Gallouët * ggml : always respect MATH_LIBRARY when defined Signed-off-by: Adrien Gallouët --------- Signed-off-by: Adrien Gallouët --- ggml/src/CMakeLists.txt | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 52754e1b9d6..3e48860bfc8 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -470,11 +470,10 @@ endforeach() target_link_libraries(ggml-base PRIVATE Threads::Threads) -find_library(MATH_LIBRARY m) -if (MATH_LIBRARY) - if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT}) - target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY}) - endif() +if (DEFINED MATH_LIBRARY) + target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY}) +elseif (NOT WIN32 AND NOT DEFINED ENV{ONEAPI_ROOT}) + target_link_libraries(ggml-base PRIVATE m) endif() if (CMAKE_SYSTEM_NAME MATCHES "Android") From 6fceff2eb4b248e57f69b4ed6d1cf82a471ad493 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Tue, 28 Apr 2026 09:02:32 +0200 Subject: [PATCH 209/249] ggml : skip already registered backends and devices (llama/22296) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- ggml/src/ggml-backend-reg.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 0587109212e..8165ae2c8bb 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -181,6 +181,12 @@ struct ggml_backend_registry { return; } + for (auto & entry : backends) { + if (entry.reg == reg) { + return; + } + } + #ifndef NDEBUG GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n", __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg)); @@ -192,6 +198,12 @@ struct ggml_backend_registry { } void register_device(ggml_backend_dev_t device) { + for (auto & dev : devices) { + if (dev == device) { + return; + } + } + #ifndef NDEBUG GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device)); #endif From 0fa31f9bb612e55b92b6877d729b947c9e6db4e0 Mon Sep 17 00:00:00 2001 From: Emil Askerov <56842174+EmilAskerov@users.noreply.github.com> Date: Tue, 28 Apr 2026 13:19:06 +0300 Subject: [PATCH 210/249] ggml: improve SPIR-V headers detection with __has_include (llama/21918) * ggml: improve SPIR-V headers detection with __has_include while preserving original _WIN32 logic * Address review comments: fix fallback logic and add FreeBSD support * Remove spirv_cross fallback as per review * Remove redundant __has_include check --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d4acee8b1df..6256639ab97 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -20,12 +20,19 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher() #include -// SPIRV-Headers: LunarG Windows SDK uses Include/spirv-headers/spirv.hpp (not spirv/unified1/). MinGW/MSYS2 and -// Linux packages use Khronos layout spirv/unified1/spirv.hpp. See docs/build.md#vulkan. -#if defined(_WIN32) && !defined(__MINGW32__) -#include + +// SPIR-V Headers: different SDK installations expose different include paths. +// LunarG Vulkan SDK on Windows typically provides . +// Linux packages, MSYS2 and MinGW often use the Khronos layout . +#if __has_include() +# include +#elif __has_include() +# include +#elif __has_include() +# include #else -#include + // Fallback to let the compiler throw a standard "file not found" error +# include #endif #include From 35fa508360cf2baee08c5eeb7b78c01bc79af000 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 28 Apr 2026 12:28:12 +0200 Subject: [PATCH 211/249] vulkan: add barrier after writetimestamp (llama/21865) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6256639ab97..69c24bb5877 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -13014,6 +13014,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr if (vk_perf_logger_enabled && vk_perf_logger_concurrent) { ctx->query_node_idx[ctx->query_idx] = node_idx; compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + ggml_vk_sync_buffers(ctx, compute_ctx); } } // Add all fused nodes to the unsynchronized lists. @@ -14503,6 +14504,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg compute_ctx = ggml_vk_get_compute_ctx(ctx); ctx->query_idx = 0; compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + ggml_vk_sync_buffers(ctx, compute_ctx); } ctx->prealloc_y_last_pipeline_used = nullptr; @@ -14739,6 +14741,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i]; ctx->query_fusion_names[ctx->query_idx] = fusion_string; compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + ggml_vk_sync_buffers(ctx, compute_ctx); } else { // track a fusion string and number of fused ops for the current node_idx ctx->query_fusion_names[i] = fusion_string; From 4ea5b6febcbab8c100da752d8b826afbcfec1382 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 28 Apr 2026 07:27:17 -0700 Subject: [PATCH 212/249] ggml-webgpu: fix buffer aliasing for ssm_scan and refactor aliasing logic (llama/22456) * Refactor buffer aliasing to be part of shader lib decisions * cleanup * formatting --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 159 ++++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 335 +++++++++--------- ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl | 6 +- .../wgsl-shaders/rms_norm_mul.wgsl | 6 +- .../ggml-webgpu/wgsl-shaders/ssm_scan.wgsl | 25 ++ 5 files changed, 301 insertions(+), 230 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index fb2c9527f3c..34cbf3694b1 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -26,21 +26,21 @@ // Matrix multiplication parameters // Register tiling parameters -#define WEBGPU_MUL_MAT_TILE_M 4 -#define WEBGPU_MUL_MAT_TILE_N 4 -#define WEBGPU_MUL_MAT_WG_SIZE_M 8 -#define WEBGPU_MUL_MAT_WG_SIZE_N 8 +#define WEBGPU_MUL_MAT_TILE_M 4 +#define WEBGPU_MUL_MAT_TILE_N 4 +#define WEBGPU_MUL_MAT_WG_SIZE_M 8 +#define WEBGPU_MUL_MAT_WG_SIZE_N 8 #define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8 #define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32 // Subgroup matrix parameters // The number of subgroups in the M dimension -#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 // The number of subgroups in the N dimension -#define WEBGPU_MUL_MAT_SUBGROUP_N 4 +#define WEBGPU_MUL_MAT_SUBGROUP_N 4 // The number of subgroup matrices each subgroup accumulates over -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32 #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32 @@ -59,19 +59,32 @@ template inline void ggml_webgpu_hash_combine(size_t & seed, const seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +// Calculates base address of a tensor ignoring the fake base pointer +inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) { + const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor; + return (uintptr_t) base_tensor->data + tensor->view_offs; +} + +inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) { + return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b); +} + +inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) { + return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) && + ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a); +} + struct ggml_webgpu_shader_lib_context { ggml_tensor * src0; ggml_tensor * src1; ggml_tensor * src2; ggml_tensor * src3; ggml_tensor * src4; + ggml_tensor * src5; ggml_tensor * dst; uint32_t max_wg_size; size_t wg_mem_limit_bytes = 0; - bool inplace = false; - bool overlap = false; - bool src_overlap = false; bool supports_subgroups = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; @@ -88,6 +101,14 @@ struct webgpu_pipeline { struct ggml_webgpu_generic_shader_decisions { uint32_t wg_size = 0; + bool inplace = false; +}; + +struct ggml_webgpu_binary_shader_decisions { + uint32_t wg_size = 0; + bool inplace = false; + bool overlap = false; + bool src_overlap = false; }; struct ggml_webgpu_processed_shader { @@ -102,11 +123,12 @@ struct ggml_webgpu_ssm_conv_shader_decisions { }; struct ggml_webgpu_ssm_scan_pipeline_key { - int type; - int d_state; + int type; + int d_state; + bool xbc_overlap; bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const { - return type == other.type && d_state == other.d_state; + return type == other.type && d_state == other.d_state && xbc_overlap == other.xbc_overlap; } }; @@ -115,6 +137,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.type); ggml_webgpu_hash_combine(seed, key.d_state); + ggml_webgpu_hash_combine(seed, key.xbc_overlap); return seed; } }; @@ -122,6 +145,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash { struct ggml_webgpu_ssm_scan_shader_decisions { uint32_t wg_size; uint32_t tokens_per_tile; + bool xbc_overlap = false; }; /** Argsort **/ @@ -242,6 +266,13 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash { } }; +struct ggml_webgpu_rms_norm_mul_shader_decisions { + uint32_t wg_size = 0; + bool inplace = false; + bool overlap = false; + bool src_overlap = false; +}; + /** Pad **/ struct ggml_webgpu_pad_pipeline_key { bool circular; @@ -503,11 +534,12 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { }; struct ggml_webgpu_flash_attn_decisions { - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; - uint32_t q_tile = 0; - uint32_t kv_tile = 0; - uint32_t wg_size = 0; - bool kv_direct = false; + uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; + bool kv_direct = false; + bool kv_overlap = false; }; inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; @@ -552,7 +584,7 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_ key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; key.kv_direct = kv_direct; - key.kv_overlap = context.src_overlap; + key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); key.has_mask = has_mask; key.has_sinks = has_sinks; key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; @@ -1021,7 +1053,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_row_norm_pipeline_key key = {}; key.op = context.dst->op; - key.inplace = context.inplace; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); auto it = row_norm_pipelines.find(key); if (it != row_norm_pipelines.end()) { @@ -1051,8 +1083,12 @@ class ggml_webgpu_shader_lib { const uint32_t row_norm_wg_size = 128u; uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size); defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - auto processed = preprocessor.preprocess(wgsl_row_norm, defines); - row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + auto processed = preprocessor.preprocess(wgsl_row_norm, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + decisions->inplace = key.inplace; + row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + row_norm_pipelines[key].context = decisions; return row_norm_pipelines[key]; } @@ -1127,7 +1163,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_set_pipeline_key key = {}; key.type = context.dst->type; - key.inplace = context.inplace; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); auto it = set_pipelines.find(key); if (it != set_pipelines.end()) { @@ -1160,6 +1196,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_set, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; set_pipelines[key] = pipeline; @@ -1355,7 +1392,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_scale_pipeline_key key = {}; - key.inplace = context.inplace; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); auto it = scale_pipelines.find(key); if (it != scale_pipelines.end()) { @@ -1375,6 +1412,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_scale, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; scale_pipelines[key] = pipeline; @@ -1468,6 +1506,8 @@ class ggml_webgpu_shader_lib { ggml_webgpu_ssm_scan_pipeline_key key = {}; key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; + key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1499,12 +1539,17 @@ class ggml_webgpu_shader_lib { variant += "_wg_reduce"; } + if (key.xbc_overlap) { + defines.push_back("XBC_OVERLAP"); + } + variant += "_d" + std::to_string(key.d_state); auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines); auto decisions = std::make_shared(); decisions->wg_size = wg_size; decisions->tokens_per_tile = tokens_per_tile; + decisions->xbc_overlap = key.xbc_overlap; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; ssm_scan_pipelines[key] = pipeline; @@ -1764,11 +1809,9 @@ class ggml_webgpu_shader_lib { uint32_t tile_k; if (key.use_subgroup_matrix) { - tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT - : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT; + tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT; } else { - tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT - : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; } // Tiles @@ -2001,9 +2044,8 @@ class ggml_webgpu_shader_lib { defines.push_back("SCALAR"); // mul_mat_id is register-tile only. - const uint32_t tile_k = ggml_is_quantized(context.src0->type) - ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT - : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + const uint32_t tile_k = + ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; // Tiles defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); @@ -2039,8 +2081,8 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.op = op; key.is_unary = is_unary; - key.inplace = context.inplace; - key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0); + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL; + key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0); auto it = unary_pipelines.find(key); if (it != unary_pipelines.end()) { @@ -2098,6 +2140,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_unary, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; unary_pipelines[key] = pipeline; @@ -2106,9 +2149,9 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_rms_norm_mul_pipeline_key key = {}; - key.inplace = context.inplace; - key.overlap = context.overlap; - key.src_overlap = context.src_overlap; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst); + key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1); auto it = rms_norm_mul_pipelines.find(key); if (it != rms_norm_mul_pipelines.end()) { @@ -2132,12 +2175,15 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines); - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - rms_norm_mul_pipelines[key] = pipeline; + auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines); + auto pipeline_decisions = std::make_shared(); + pipeline_decisions->wg_size = context.max_wg_size; + pipeline_decisions->inplace = key.inplace; + pipeline_decisions->overlap = key.overlap; + pipeline_decisions->src_overlap = key.src_overlap; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = pipeline_decisions; + rms_norm_mul_pipelines[key] = pipeline; return rms_norm_mul_pipelines[key]; } @@ -2145,9 +2191,9 @@ class ggml_webgpu_shader_lib { ggml_webgpu_binary_pipeline_key key = {}; key.type = context.dst->type; key.op = context.dst->op; - key.inplace = context.inplace; - key.overlap = context.overlap; - key.src_overlap = context.src_overlap; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst); + key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1); auto it = binary_pipelines.find(key); if (it != binary_pipelines.end()) { @@ -2186,11 +2232,15 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_binary, defines); - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; + auto processed = preprocessor.preprocess(wgsl_binary, defines); + auto pipeline_decisions = std::make_shared(); + pipeline_decisions->wg_size = context.max_wg_size; + pipeline_decisions->inplace = key.inplace; + pipeline_decisions->overlap = key.overlap; + pipeline_decisions->src_overlap = key.src_overlap; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; + pipeline.context = pipeline_decisions; binary_pipelines[key] = pipeline; return binary_pipelines[key]; } @@ -2351,7 +2401,8 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); } - auto pipeline_decisions = std::make_shared(decisions); + auto pipeline_decisions = std::make_shared(decisions); + pipeline_decisions->kv_overlap = key.kv_overlap; defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile)); defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile)); defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size)); @@ -2543,7 +2594,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_rope_pipeline_key key = {}; key.type = context.dst->type; - key.inplace = context.inplace; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); key.has_ff = (context.src2 != nullptr); auto it = rope_pipelines.find(key); @@ -2582,6 +2633,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_rope, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; rope_pipelines[key] = pipeline; @@ -2593,7 +2645,7 @@ class ggml_webgpu_shader_lib { key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32; key.has_mask = (context.src1 != nullptr); key.has_sink = (context.src2 != nullptr); - key.inplace = context.inplace; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); auto it = soft_max_pipelines.find(key); if (it != soft_max_pipelines.end()) { @@ -2634,6 +2686,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_soft_max, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; soft_max_pipelines[key] = pipeline; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 6d861c0c781..762d9f8d1b4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -108,12 +108,9 @@ static inline uint32_t ggml_webgpu_u32_from_f32(float value) { // their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT -// Always returns the base offset of a tensor, regardless of views. -static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { - if (tensor->view_src) { - return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base; - } - return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base; +static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { + const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor; + return (size_t) ((uintptr_t) base_tensor->data - (uintptr_t) webgpu_ptr_base) + tensor->view_offs; } /* Struct definitions */ @@ -375,10 +372,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, buffer = device.CreateBuffer(&buffer_desc); } -static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { - return webgpu_tensor_offset(tensor) + tensor->view_offs; -} - static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; return ctx->buffer; @@ -398,34 +391,31 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); } -// Used to determine if two tensors are the same for in-place operations -static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); -} +struct ggml_webgpu_merged_binding_range { + size_t offset; + size_t size; +}; -// Used to determine if two tensors share the same buffer and their byte ranges overlap, -static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && - ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); -} +static ggml_webgpu_merged_binding_range ggml_webgpu_tensor_merged_binding_range( + webgpu_context & ctx, + std::initializer_list tensors) { + size_t merged_offset = SIZE_MAX; + size_t merged_end = 0; -struct binary_overlap_flags { - bool inplace; // src0 == dst - bool overlap; // src1 == dst - bool src_overlap; -}; + for (ggml_tensor * tensor : tensors) { + const size_t bind_offset = ggml_webgpu_tensor_align_offset(ctx, tensor); + const size_t bind_end = bind_offset + ggml_webgpu_tensor_binding_size(ctx, tensor); -static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { - binary_overlap_flags flags = {}; - flags.inplace = ggml_webgpu_tensor_equal(src0, dst); - flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); - flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); + merged_offset = std::min(merged_offset, bind_offset); + merged_end = std::max(merged_end, bind_end); + } + + return { merged_offset, merged_end - merged_offset }; +} - return flags; +static uint32_t ggml_webgpu_tensor_merged_element_offset(const ggml_tensor * tensor, + const ggml_webgpu_merged_binding_range & merged_range) { + return (uint32_t) ((ggml_webgpu_tensor_offset(tensor) - merged_range.offset) / ggml_type_size(tensor->type)); } static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding, @@ -753,18 +743,16 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - const bool inplace = ggml_webgpu_tensor_equal(src0, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; shader_lib_ctx.src1 = src1; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); + const bool inplace = decisions->inplace; const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst); const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type); @@ -1126,19 +1114,39 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src4 = src4; + shader_lib_ctx.src5 = src5; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; - webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx); + webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + const bool xbc_overlap = decisions->xbc_overlap; + + uint32_t offset_x = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)); + uint32_t offset_B = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)); + uint32_t offset_C = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)); + size_t xbc_bind_offset = 0; + size_t xbc_bind_size = 0; + if (xbc_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { src1, src4, src5 }); + xbc_bind_offset = merged_range.offset; + xbc_bind_size = merged_range.size; + offset_x = ggml_webgpu_tensor_merged_element_offset(src1, merged_range); + offset_B = ggml_webgpu_tensor_merged_element_offset(src4, merged_range); + offset_C = ggml_webgpu_tensor_merged_element_offset(src5, merged_range); + } std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + offset_x, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)), + offset_B, + offset_C, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1174,11 +1182,24 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx, }; std::vector entries = { - ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), }; + if (xbc_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), xbc_bind_offset, xbc_bind_size)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src6)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst)); + } const uint32_t total_wg = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]); const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; @@ -1653,23 +1674,38 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = Q; + shader_lib_ctx.src1 = K; + shader_lib_ctx.src2 = V; + shader_lib_ctx.src3 = mask; + shader_lib_ctx.src4 = sinks; + shader_lib_ctx.dst = dst; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( + shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + auto * decisions = static_cast(pipeline.context.get()); const int has_mask = (mask != nullptr); const int has_sinks = (sinks != nullptr); - const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type; + const bool kv_overlap = decisions->kv_overlap; uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); size_t kv_bind_offset = 0; size_t kv_bind_size = 0; if (kv_overlap) { - const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K); - const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V); - const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K); - const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V); - kv_bind_offset = std::min(k_bind_offset, v_bind_offset); - kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset; - offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type)); - offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type)); + const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V }); + kv_bind_offset = merged_range.offset; + kv_bind_size = merged_range.size; + offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range); + offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range); } std::vector params = { @@ -1720,26 +1756,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, } entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = Q; - shader_lib_ctx.src1 = K; - shader_lib_ctx.src2 = V; - shader_lib_ctx.src3 = mask; - shader_lib_ctx.src4 = sinks; - shader_lib_ctx.dst = dst; - shader_lib_ctx.src_overlap = kv_overlap; - shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( - shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - auto * decisions = static_cast(pipeline.context.get()); - if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches @@ -1921,18 +1937,17 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; - bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src; shader_lib_ctx.src1 = nullptr; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); + const bool inplace = decisions->inplace; uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -1994,41 +2009,38 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; shader_lib_ctx.src1 = src1; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = flags.inplace; - shader_lib_ctx.overlap = flags.overlap; - shader_lib_ctx.src_overlap = flags.src_overlap; - - webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); + webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); uint32_t ne = (uint32_t) ggml_nelements(dst); size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0); size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1); - uint32_t offset_merged_src0 = 0; - uint32_t offset_merged_src1 = 0; - if (flags.src_overlap) { - size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); - offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); - offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); + uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)); + uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)); + size_t merged_offset = 0; + size_t merged_size = 0; + if (decisions->src_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 }); + merged_offset = merged_range.offset; + merged_size = merged_range.size; + offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range); + offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range); } std::vector params = { ne, - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + offset_src0, + offset_src1, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - offset_merged_src0, - offset_merged_src1, (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), @@ -2048,12 +2060,9 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, std::vector entries; - if (flags.src_overlap) { - size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); - size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0), - src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1)); - entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, - merged_end - merged_offset)); + if (decisions->src_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size)); entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } else { entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), @@ -2062,7 +2071,7 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), src1_webgpu_tensor_align_offset, ggml_webgpu_tensor_binding_size(ctx, src1))); - if (!flags.inplace && !flags.overlap) { + if (!decisions->inplace && !decisions->overlap) { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); } } @@ -2168,29 +2177,15 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); } - bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || - (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); - bool inplace = ggml_webgpu_tensor_equal(rn_src, dst); - bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); - - uint32_t offset_merged_rn_src = 0; - uint32_t offset_merged_mul_src = 0; - size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, rn_src); - size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src); - - if (src_overlap) { - size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset); - offset_merged_rn_src = - (uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type)); - offset_merged_mul_src = - (uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type)); - } + uint32_t offset_rn_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)); + uint32_t offset_mul_src = + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)); + size_t merged_offset = 0; + size_t merged_size = 0; std::vector params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)), - offset_merged_rn_src, - offset_merged_mul_src, + offset_rn_src, + offset_mul_src, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)), (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)), @@ -2214,16 +2209,32 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context std::vector entries; - if (inplace || overlap) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = rn_src; + shader_lib_ctx.src1 = mul_src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + if (decisions->src_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { rn_src, mul_src }); + merged_offset = merged_range.offset; + merged_size = merged_range.size; + offset_rn_src = ggml_webgpu_tensor_merged_element_offset(rn_src, merged_range); + offset_mul_src = ggml_webgpu_tensor_merged_element_offset(mul_src, merged_range); + params[0] = offset_rn_src; + params[1] = offset_mul_src; + } + + if (decisions->inplace || decisions->overlap) { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); - } else if (src_overlap) { - size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset); - size_t merged_end = - std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src), - mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src)); - entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, - merged_end - merged_offset)); + } else if (decisions->src_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, merged_size)); entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } else { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); @@ -2231,20 +2242,10 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); } - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = inplace; - shader_lib_ctx.overlap = overlap; - shader_lib_ctx.src_overlap = src_overlap; - - webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); } static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - bool inplace = ggml_webgpu_tensor_equal(src, dst); - std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -2261,18 +2262,18 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)) // epsilon, treated as f32 in the shader }; - std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; - if (!inplace) { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); - } - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = inplace; - webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); + webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; + if (!decisions->inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src)); } @@ -2287,14 +2288,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, shader_lib_ctx.src2 = src2; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int has_freq_factor = (src2 != nullptr); + const bool inplace = decisions->inplace; + const int has_freq_factor = (src2 != nullptr); const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -2421,14 +2421,11 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, } static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - bool inplace = ggml_webgpu_tensor_equal(src, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src; shader_lib_ctx.src1 = nullptr; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2454,7 +2451,7 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s // bindgroups unchanged std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; - if (!inplace) { + if (!decisions->inplace) { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } @@ -2473,17 +2470,17 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, shader_lib_ctx.src2 = src2; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); - webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); + webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int has_mask = (src1 != nullptr); - const int has_sink = (src2 != nullptr); - float max_bias = ggml_get_op_params_f32(dst, 1); - float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); - float m0 = powf(2.0f, -(max_bias) / n_head_log2); - float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const bool inplace = decisions->inplace; + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); + float max_bias = ggml_get_op_params_f32(dst, 1); + float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -3079,7 +3076,7 @@ static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend, size_t size) { GGML_UNUSED(backend); auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; // Write aligned portion buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); @@ -3161,7 +3158,7 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; // This is a trick to set all bytes of a u32 to the same 1 byte value. uint32_t val32 = (uint32_t) value * 0x01010101; @@ -3180,7 +3177,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); @@ -3212,7 +3209,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, << ", " << offset << ", " << size << ")"); wgpu::Device device = buf_ctx->global_ctx->device; - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; size_t final_size = size; if (size % 4 != 0) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl index a748dc1b86c..605de7aa7be 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -7,8 +7,6 @@ struct Params { offset_src0: u32, offset_src1: u32, offset_dst: u32, - offset_merged_src0: u32, - offset_merged_src1: u32, stride_src0_0: u32, stride_src0_1: u32, @@ -134,8 +132,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) { @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x < params.ne) { - let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x); - let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x); + let src0_i = params.offset_src0 + src0_index(gid.x); + let src1_i = params.offset_src1 + src1_index(gid.x); update(params.offset_dst + gid.x, src0_i, src1_i); } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl index 74aaa2753ae..fd20a4e54c9 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl @@ -66,8 +66,6 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) struct Params { offset_rn_src: u32, offset_mul_src: u32, - offset_merged_rn_src: u32, - offset_merged_mul_src: u32, offset_dst: u32, stride_rn_src1: u32, @@ -107,8 +105,8 @@ fn main(@builtin(workgroup_id) wid: vec3, i = i % (params.ne2 * params.ne1); let i2 = i / params.ne1; let i1 = i % params.ne1; - let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1; - let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1; + let i_rn_src_row = params.offset_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1; + let i_mul_src_row = params.offset_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1; let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl index 64324738591..05761dec353 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl @@ -45,6 +45,14 @@ struct Params { }; @group(0) @binding(0) var s_in: array; +#ifdef XBC_OVERLAP +@group(0) @binding(1) var x_B_C_merged: array; +@group(0) @binding(2) var dt: array; +@group(0) @binding(3) var A: array; +@group(0) @binding(4) var ids: array; +@group(0) @binding(5) var dst: array; +@group(0) @binding(6) var params: Params; +#else @group(0) @binding(1) var x: array; @group(0) @binding(2) var dt: array; @group(0) @binding(3) var A: array; @@ -53,6 +61,7 @@ struct Params { @group(0) @binding(6) var ids: array; @group(0) @binding(7) var dst: array; @group(0) @binding(8) var params: Params; +#endif var shared_x_dt: array; var shared_dtsp: array; @@ -98,7 +107,11 @@ fn main( let dt0 = dt[dt_idx]; let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0); shared_dtsp[tid] = dtsp; +#ifdef XBC_OVERLAP + shared_x_dt[tid] = x_B_C_merged[x_idx] * dtsp; +#else shared_x_dt[tid] = x[x_idx] * dtsp; +#endif } } @@ -116,16 +129,28 @@ fn main( let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3; let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3; +#ifdef XBC_OVERLAP + let s = s_prev * dA + x_B_C_merged[b_idx] * x_dt; +#else let s = s_prev * dA + B[b_idx] * x_dt; +#endif s_prev = s; #ifdef USE_SUBGROUP_REDUCTION +#ifdef XBC_OVERLAP + let subgroup_partial = subgroupAdd(s * x_B_C_merged[c_idx]); +#else let subgroup_partial = subgroupAdd(s * C[c_idx]); +#endif if (subgroup_invocation_id == 0u) { shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial; } +#else +#ifdef XBC_OVERLAP + shared_reduce[reduce_idx] = s * x_B_C_merged[c_idx]; #else shared_reduce[reduce_idx] = s * C[c_idx]; +#endif #endif workgroupBarrier(); From e69c109aac3f7ca1643a50027603902c123a3849 Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Tue, 28 Apr 2026 15:31:04 +0000 Subject: [PATCH 213/249] vulkan: Coalesce Q4_K/Q5_K scale loads (llama/21751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Some SPIR-V compilers (notably mesa) don't handle the current vulkan Q4_K/Q5_K scale load pattern in mul_mat particularly well. While reading three `u8`s from the 12-byte scale array should (at least on some hardware) result in loading the full 12 bytes in a single LOAD followed by whatever extraction is needed, at least the ANV Intel driver really can't practically perform this optimization. `mesa`'s unsigned upper bound logic doesn't handle tracking bounds through ternary, resulting in the `(is < 4) ? ... : is - 4` having an infinite upper bound (as it cannot prove `is - 4` doesn't underflow). While this could still be rectified if mesa looked at the array bounds, it currently doesn't and `glslc` currently emits SPIR-V that doesn't allow for this optimization anyway (though maybe it will at some point, see https://github.com/KhronosGroup/glslang/issues/4206). In mul_mat_vecq we took a different approach to loading the same fields. We read the first two bytes we needed from `scale` then took a branch before deciding whether we needed to read a third byte. In mesa this did, indeed, lead to a top-level branch with conditional loads. As such these loads ended up not being coalesced either (at least in the ANV driver) resulting in additional instructions in our hot loop. Instead, here, we go ahead and force loading the full 12 bytes and extract the bits we need from the packed-u32s instead. In mul_mat there's a few less ternaries and only one extra shift, so even on drivers that did optimize the previous loads properly the only material change should be pulling a few extra bytes into registers (which on most hardware won't cost anything anyway, though ironically on Intel it theoretically could). In mul_mat_vecq this requires a bit of extra math and may read bytes from the u32 that weren't needed, but it seems likely avoiding the branch is a win on most platforms. On Intel Xe2/mesa 26.0.4 with the optimizations from https://gitlab.freedesktop.org/mesa/mesa/-/work_items/15162, for shader matmul_id_subgroup_q4_k_f32_f16acc_aligned_l: * Instruction Count: 2753 -> 2688 * SEND Count: 269 -> 261 * Cycle Count: 273976 -> 266138 * Max live registers: 248 -> 246 * Non SSA regs after NIR: 381 -> 382 for shader matmul_id_subgroup_q5_k_f32_f16acc_aligned_l: * Instruction Count: 2767 -> 2702 * SEND Count: 271 -> 263 * Cycle Count: 274140 -> 268144 * Max live registers: 248 -> 246 * Non SSA regs after NIR: 381 -> 382 for shader mul_mat_vec_id_q4_k_q8_1_f32: * Instruction Count: 1930 -> 1646 * SEND Count: 116 -> 71 * Cycle Count: 1348306 -> 843350 * Max live registers: 78 -> 84 * Non SSA regs after NIR: 300 -> 135 for shader mul_mat_vec_id_q5_k_q8_1_f32: * Instruction Count: 2207 -> 1922 * SEND Count: 131 -> 86 * Cycle Count: 1392012 -> 1037836 * Max live registers: 90 -> 90 * Non SSA regs after NIR: 300 -> 135 for shader mul_mat_vec_q4_k_q8_1_f32: * Instruction Count: 2029 -> 1749 * SEND Count: 111 -> 66 * Cycle Count: 1347278 -> 840118 * Max live registers: 74 -> 80 * Non SSA regs after NIR: 299 -> 134 for shader mul_mat_vec_q5_k_q8_1_f32: * Instruction Count: 2307 -> 2022 * SEND Count: 126 -> 81 * Cycle Count: 1379820 -> 954042 * Max live registers: 86 -> 86 * Non SSA regs after NIR: 299 -> 134 On one Arc Pro B60, unsloth/Qwen3.5-35B-A3B-GGUF:UD-Q4_K_XL: * pp512: 907.34 ± 9.28 -> 941.94 ± 10.53 (+4%) * pp2048: 897.95 ± 1.82 -> 931.55 ± 1.79 (+4%) * tg128: 49.49 ± 0.02 -> 49.86 ± 0.05 (+ <1%) On one Arc Pro B60, unsloth/Qwen3.5-27B-GGUF:Q4_K_S: * pp512: 324.13 ± 10.52 -> 354.33 ± 6.81 (+9%) * pp2048: 329.80 ± 0.25 -> 357.10 ± 0.06 (+8%) * tg128: 17.11 ± 0.01 -> 18.11 ± 0.01 (+6%) On four Arc Pro B60s, unsloth/Qwen3.5-122B-A10B-GGUF:Q5_K_S with -sm layer (note that -sm tensor improvements will naturally be less): * pp512: 264.55 ± 2.81 -> 280.45 ± 3.94 (+6%) * pp2048: 319.32 ± 2.72 -> 335.70 ± 3.48 (+5%) * tg128: 26.39 ± 0.01 -> 26.67 ± 0.01 (+1%) --- .../vulkan-shaders/mul_mat_vecq_funcs.glsl | 23 +++++--- .../vulkan-shaders/mul_mm_funcs.glsl | 54 ++++++++++--------- 2 files changed, 44 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index e99108dc50c..bc580aeeb83 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -296,13 +296,22 @@ vec2 get_dm_scale(uint ib, uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs; const uint is = iqs_k / 8; - u8vec2 scale_dm; - if (is < 4) { - scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); - } else { - scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), - (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); - } + + const uvec3 scales = uvec3(data_a_packed32[ib_k].scales[0], + data_a_packed32[ib_k].scales[1], + data_a_packed32[ib_k].scales[2]); + const uint scalesoffs = (is & 3) * 8; + + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); + u8vec2 scale_dm = u8vec2(sc, mbyte); return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 6e4a29d2fdd..73595168984 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -201,19 +201,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 loadd = vec2(data_a[ib].dm); - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; - - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const uvec3 scales = uvec3(data_a_packed32[ib].scales[0], + data_a_packed32[ib].scales[1], + data_a_packed32[ib].scales[2]); + const uint scalesoffs = (is & 3) * 8; + + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); const float d = loadd.x * sc; const float m = -loadd.y * mbyte; @@ -237,19 +238,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 loadd = vec2(data_a[ib].dm); - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; - - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const uvec3 scales = uvec3(data_a_packed32[ib].scales[0], + data_a_packed32[ib].scales[1], + data_a_packed32[ib].scales[2]); + const uint scalesoffs = (is & 3) * 8; + + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); const float d = loadd.x * sc; const float m = -loadd.y * mbyte; From b553e17071862cd10feed563f8afb027cc713e18 Mon Sep 17 00:00:00 2001 From: lnigam Date: Wed, 29 Apr 2026 01:07:35 +0530 Subject: [PATCH 214/249] =?UTF-8?q?ggml-cuda:=20add=20flash-attn=20support?= =?UTF-8?q?=20for=20DKQ=3D320/DV=3D256=20with=20ncols2=3D32=20(=E2=80=A6?= =?UTF-8?q?=20(#22286)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (GQA=32) Adds MMA-f16 and tile kernel configs, dispatch logic, template instances, and tile .cu file for Mistral Small 4 (head sizes 320/256), restricting to ncols2=32 to support GQA ratio 32 only. * Adding check to return BEST_FATTN_KERNEL_NONE in case GQA!=32 * Apply suggestions from code review Address review comments Co-authored-by: Johannes Gäßler * Address review comments and making kernel config default to DQK=512, DV=512 instead of DQK=256,DV=256 * Fixed bug with sinks=1, with ncols=32, there are two warp-groups created but sinks index is same(0,...,15) for both the groups hence with sinks=1, output is not matching with CPU output. Added sink_base which will be base index for each warp_group (threadIdx.y / np) * Apply suggestions from code review Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/template-instances/generate_cu_files.py Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 15 +++++++- ggml/src/ggml-cuda/fattn-tile.cu | 4 ++ ggml/src/ggml-cuda/fattn-tile.cuh | 37 ++++++++++++++----- ggml/src/ggml-cuda/fattn.cu | 24 ++++++++++++ ...ttn-mma-f16-instance-ncols1_1-ncols2_32.cu | 1 + ...ttn-mma-f16-instance-ncols1_2-ncols2_32.cu | 1 + .../fattn-tile-instance-dkq320-dv256.cu | 5 +++ .../template-instances/generate_cu_files.py | 15 +++++--- 8 files changed, 86 insertions(+), 16 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index e185449d491..3f01e858de7 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,6 +66,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); @@ -85,6 +88,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); @@ -118,6 +124,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); @@ -1217,7 +1226,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col); + const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col)); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); @@ -1825,6 +1834,10 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); +// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build: +extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32); +extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32); + // For GLM 4.7 Flash extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 25b16e83cac..d60634cc0e9 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; + case 320: { + GGML_ASSERT(V->ne[0] == 256); + ggml_cuda_flash_attn_ext_tile_case<320, 256>(ctx, dst); + } break; case 512: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 26721cc4c7d..928b856f9d2 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) @@ -128,6 +130,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64) @@ -195,6 +199,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) @@ -264,6 +270,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64) @@ -1144,14 +1152,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm } } - if (Q->ne[1] > 8/ncols2) { - constexpr int cols_per_block = 16; - const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); - return; + if constexpr (ncols2 <= 16) { + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } } if constexpr (ncols2 <= 8) { @@ -1210,6 +1220,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + if constexpr (DKQ == 320) { // Mistral Small 4 + if (use_gqa_opt && gqa_ratio % 32 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32"); + } + if constexpr (DKQ == 576) { if (use_gqa_opt && gqa_ratio % 16 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); @@ -1221,7 +1239,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm } } - if constexpr (DKQ <= 512) { + if constexpr (DKQ <= 512 && DKQ != 320) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1275,5 +1293,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(320, 256); extern DECL_FATTN_TILE_CASE(512, 512); extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index ea6607cd337..8256591b21d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -143,6 +143,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(V->ne[0] == 256); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; + case 320: + // For Mistral Small 4, go straight to the ncols1 switch (ncols2=32-only build). + GGML_ASSERT(V->ne[0] == 256); + { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(gqa_ratio % 32 == 0); + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 32>(ctx, dst); + } + break; case 512: GGML_ASSERT(V->ne[0] == 512); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst); @@ -352,6 +368,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } break; + case 320: + if (V->ne[0] != 256 || !gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + if (gqa_ratio % 32 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 512: if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu index 1f554d81e5e..8fc3b17976e 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32); DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu index 264751d65ec..abd2b21ce04 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu new file mode 100644 index 00000000000..c91f508079d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(320, 256); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 841059c15b5..5e9a1cb2eb3 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,7 +3,7 @@ from glob import glob import os -HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576] +HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576] TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] @@ -62,7 +62,7 @@ def get_short_name(long_quant_name): os.remove(filename) for head_size_kq in HEAD_SIZES_KQ: - head_size_v = head_size_kq if head_size_kq != 576 else 512 + head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) @@ -84,13 +84,16 @@ def get_short_name(long_quant_name): continue if head_size_kq == 72: continue - if head_size_kq == 512 and ncols2 not in (4, 8): + # Skip compilation of unused ncols2 values for niche head sizes: + if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4 continue - if head_size_kq != 576 and ncols2 in (16, 32): + if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4 continue - if head_size_kq == 576 and ncols2 not in (4, 16, 32): + if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash continue - head_size_v = head_size_kq if head_size_kq != 576 else 512 + if head_size_kq not in (320, 576) and ncols2 in (16, 32): + continue + head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: From c200b588f88301ab77f8f368355ed718ecb18ce7 Mon Sep 17 00:00:00 2001 From: Michael Wand Date: Tue, 28 Apr 2026 15:47:42 -0700 Subject: [PATCH 215/249] ggml-cuda: Repost of 21896: Blackwell native NVFP4 support (llama/22196) --- ggml/src/ggml-cuda/common.cuh | 12 ++ ggml/src/ggml-cuda/mma.cuh | 34 +++-- ggml/src/ggml-cuda/mmq.cu | 21 ++- ggml/src/ggml-cuda/mmq.cuh | 230 ++++++++++++++++++++------------ ggml/src/ggml-cuda/mmvq.cu | 3 + ggml/src/ggml-cuda/quantize.cu | 148 ++++++++++++++++---- ggml/src/ggml-cuda/quantize.cuh | 2 +- 7 files changed, 319 insertions(+), 131 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 3aec1742ee1..10817505d9f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -830,6 +830,18 @@ static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) { #endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000 } +static __device__ __forceinline__ uint8_t ggml_cuda_fp32_to_ue4m3(float x) { +#if defined(BLACKWELL_MMA_AVAILABLE) // This is used for NVFP4 subblock scale quantizations only + if (!(x > 0.0f)) { + return 0; + } + const __nv_fp8_e4m3 xf(x); + return xf.__x; +#else + NO_DEVICE_CODE; // Used only for NVFP4 Scales for Activations, only for Blackwell +#endif // defined(BLACKWELL_MMA_AVAILABLE) +} + __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) { const uint8_t sign_bit = (x < 0.0f) << 3; float ax = fabsf(x) * e; diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index b0f674635f1..79bb2934c5f 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -1015,25 +1015,35 @@ namespace ggml_cuda_mma { #endif // AMD_MFMA_AVAILABLE } - static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D, - const tile<16, 8, int> & A, - const tile<8, 8, int> & B, - uint32_t a_scale, - uint32_t b_scale) { + template + static __device__ __forceinline__ void mma_block_scaled_fp4(tile<16, 8, float> & D, + const tile<16, 8, int> & A, + const tile<8, 8, int> & B, + uint32_t a_scale, + uint32_t b_scale) { #ifdef BLACKWELL_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; float * Dxi = (float *) D.x; - asm volatile( - "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " - "%10, {0, 0}, %11, {0, 0};" - : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) - : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); + if constexpr (type == GGML_TYPE_MXFP4) { + asm volatile( + "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " + "%10, {0, 0}, %11, {0, 0};" + : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); + } else { + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " + "%10, {0, 0}, %11, {0, 0};" + : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); + } #else GGML_UNUSED_VARS(D, A, B, a_scale, b_scale); -#endif // BLACKWELL_MMA_AVAILABLE +#endif // BLACKWELL_MMA_AVAILABLE } static __device__ __forceinline__ void mma( diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 3f01ff5bfb0..e1add5e0331 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -122,7 +122,7 @@ void ggml_cuda_mul_mat_q( || GGML_CUDA_CC_IS_CDNA(cc); // TODO: tighter pool buffer size vs q8 path - const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4; + const bool use_native_fp4 = blackwell_mma_available(cc) && (src0->type == GGML_TYPE_MXFP4 || src0->type == GGML_TYPE_NVFP4); if (!ids) { const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + @@ -133,9 +133,9 @@ void ggml_cuda_mul_mat_q( const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1; - if (use_native_mxfp4) { + if (use_native_fp4) { static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1)); - quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, + quantize_mmq_fp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); } else { @@ -146,10 +146,8 @@ void ggml_cuda_mul_mat_q( } // Stride depends on quantization format - const int64_t s12 = use_native_mxfp4 ? - ne11 * ne10_padded * sizeof(block_fp4_mmq) / - (8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32) - : + const int64_t s12 = use_native_fp4 ? + ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : // block_fp4_mmq holds 256 values ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); const int64_t s13 = ne12*s12; @@ -198,8 +196,8 @@ void ggml_cuda_mul_mat_q( const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1; - if (use_native_mxfp4) { - quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, + if (use_native_fp4) { + quantize_mmq_fp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); } else { quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, @@ -208,8 +206,9 @@ void ggml_cuda_mul_mat_q( CUDA_CHECK(cudaGetLastError()); } - const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) : - ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); + static_assert(QK_K == 8 * QK_MXFP4, "QK_K needs to be 8 * QK_MXFP4"); + const int64_t s12 = use_native_fp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : + ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); const int64_t s13 = ne12*s12; // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 91a1b737a82..edf546d8f1e 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -10,9 +10,9 @@ using namespace ggml_cuda_mma; #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. -#define MMQ_ITER_K 256 -#define MMQ_ITER_K_MXFP4_FP4 512 -#define MMQ_NWARPS 8 +#define MMQ_ITER_K 256 +#define MMQ_ITER_K_FP4 512 +#define MMQ_NWARPS 8 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00); @@ -46,9 +46,12 @@ struct block_q8_1_mmq { int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each }; +// this struct is used for fp4 data types (currently only used for Blackwell) +// mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits +// nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales struct block_fp4_mmq { - uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc. - int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values + uint32_t d4[4]; + int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte) }; static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); @@ -143,10 +146,11 @@ static int get_mmq_y_host(const int cc) { static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) { #if defined(BLACKWELL_MMA_AVAILABLE) - return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K; -#else - return MMQ_ITER_K; +if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) { + return MMQ_ITER_K_FP4; +} #endif // defined(BLACKWELL_MMA_AVAILABLE) + return MMQ_ITER_K; } static constexpr __device__ int get_mmq_y_device() { @@ -213,8 +217,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml } #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 -#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 +#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell +#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) @@ -240,7 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; // tile sizes are the same for Q8_1 and FP4 for blackwell case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; +#if defined(BLACKWELL_MMA_AVAILABLE) + case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4; +#else case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4; +#endif // defined(BLACKWELL_MMA_AVAILABLE) case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; @@ -934,6 +942,128 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr } } +#ifdef BLACKWELL_MMA_AVAILABLE +template +static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x, + int * __restrict__ x_tile, + const int kbx0, + const int i_max, + const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4); + constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block + constexpr int rows_per_warp = warp_size / threads_per_row; + + uint32_t * x_u32 = (uint32_t *) x_tile; + + const int txi = threadIdx.x; + const int kbx = txi % threads_per_row; + const int row_in_warp = txi / threads_per_row; + + const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx; + uint32_t * x_u32_scale = x_u32 + 64 + kbx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) { + int i = i0 + threadIdx.y * rows_per_warp + row_in_warp; + + if constexpr (need_check) { + i = min(i, i_max); + } + + const block_nvfp4 * bxi = bxi_base + i * stride; + const int row_base = i * MMQ_MMA_TILE_X_K_FP4; + const int q_base = row_base + 8 * kbx; + + const uint32_t * src_qs = reinterpret_cast(bxi->qs); + +#pragma unroll + for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) { + x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0]; + x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1]; + } + + x_u32_scale[row_base] = get_int_b4(bxi->d, 0); + } +} + +// Shared MMA kernel for MXFP4 and NVFP4 on Blackwell. +// Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per +// m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3) +// and the per-type stride constant differ. +template +static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x, + const int * __restrict__ y, + float * __restrict__ sum, + const int k00) { + static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4, + "vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4"); + + typedef tile<16, 8, int> tile_A; + typedef tile<8, 8, int> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int stride = MMQ_MMA_TILE_X_K_FP4; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp / tile_C::I; + constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J; + + y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); + const int * y_qs = (const int *) y + 4; + const uint32_t * y_sc = (const uint32_t *) y; + + // 2 threads per quad supply the packed scale register to the block_scale MMA, + // see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling + const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8; + const int tidx_B = threadIdx.x / 4; + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + tile_A A[ntx][nfrags]; + uint32_t scaleA[ntx][nfrags]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int frag = 0; frag < nfrags; ++frag) { + const int k0 = k00 + frag * tile_A::J; + load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride); + scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J]; + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) { + tile_B B[nfrags]; + uint32_t scaleB[nfrags]; + +#pragma unroll + for (int frag = 0; frag < nfrags; ++frag) { + const int k0 = frag * tile_B::J; + load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K); + scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int frag = 0; frag < nfrags; ++frag) { + tile_C C = {}; + mma_block_scaled_fp4(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l]; + } + } + } + } +} +#endif // BLACKWELL_MMA_AVAILABLE + template static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x, @@ -1163,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -template -static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x, - const int * __restrict__ y, - float * __restrict__ sum, - const int k00) { - typedef tile<16, 8, int> tile_A; - typedef tile<8, 8, int> tile_B; - typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K); - - // Match layout from load_tiles_mxfp4_fp4 - const int * x_qs = (const int *) x; - const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); - const int * y_qs = (const int *) y + 4; - const uint32_t * y_sc = (const uint32_t *) y; - - // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4 - tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; - uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; - - // Block scale - // Each thread has to point to a 4 byte scale value - // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { - const int k0 = k00 + k01; - - load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0, - MMQ_MMA_TILE_X_K_FP4); - - // based on block-scaling document, 2 threads in each quad need to supply to the scale value - const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8; - scaleA[n][k01 / (2 * QI_MXFP4)] = - *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4)); - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { - tile_B B; - uint32_t scaleB; // 2xN scales - - load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K); - - scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - - mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB); -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l]; - } - } - } - } -} template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( @@ -3259,7 +3318,7 @@ struct mmq_type_traits { static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; #ifdef BLACKWELL_MMA_AVAILABLE static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma; #else static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; @@ -3270,8 +3329,13 @@ struct mmq_type_traits { template struct mmq_type_traits { static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ; +#ifdef BLACKWELL_MMA_AVAILABLE + static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma; +#else static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; +#endif // BLACKWELL_MMA_AVAILABLE static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; }; @@ -3406,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( #if defined(BLACKWELL_MMA_AVAILABLE) // FP4 tile stores 8 blocks - constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1; + constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1; #else constexpr int ne_block = 4 * QK8_1; #endif // defined(BLACKWELL_MMA_AVAILABLE) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 8f55cace1a1..da48f313a38 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -115,6 +115,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(gg case GGML_TYPE_IQ4_NL: return 6; case GGML_TYPE_IQ4_XS: return 5; case GGML_TYPE_MXFP4: return 4; + case GGML_TYPE_NVFP4: return 4; case GGML_TYPE_Q2_K: return 4; case GGML_TYPE_Q3_K: return 4; case GGML_TYPE_Q4_0: return 6; @@ -135,6 +136,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggm case GGML_TYPE_IQ3_S: return 6; case GGML_TYPE_IQ3_XXS: return 7; case GGML_TYPE_MXFP4: return 7; + case GGML_TYPE_NVFP4: return 8; case GGML_TYPE_Q2_K: return 7; case GGML_TYPE_Q3_K: return 5; default: return MMVQ_MAX_BATCH_SIZE; @@ -221,6 +223,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type case GGML_TYPE_IQ4_NL: return 7; case GGML_TYPE_IQ4_XS: return 5; case GGML_TYPE_MXFP4: return 5; + case GGML_TYPE_NVFP4: return 5; case GGML_TYPE_Q3_K: return 4; case GGML_TYPE_Q4_0: return 7; case GGML_TYPE_Q4_1: return 7; diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 4300ffc148c..52f664719ae 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -70,6 +70,102 @@ __device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) { return static_cast(biased); } + +static __global__ void quantize_mmq_nvfp4( + const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2) { +#if defined(BLACKWELL_MMA_AVAILABLE) + + const int64_t i0_base = ((int64_t) blockDim.x * blockIdx.y + threadIdx.x) * QK_NVFP4_SUB; + if (i0_base >= ne0) { + return; + } + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; + const int64_t i01 = ids ? ids[i1] : i1; + const int64_t k_block = i0_base / QK_K; + const int64_t blocks_per_col = (ne0 + QK_K - 1) / QK_K; + if (k_block >= blocks_per_col) { + return; + } + + const int64_t ib = blockIdx.z * ((int64_t) blocks_per_col * ne1) + k_block * ne1 + blockIdx.x; + block_fp4_mmq * y = (block_fp4_mmq *) vy; + block_fp4_mmq * yb = y + ib; + + const int sub = (i0_base % QK_K) / QK_NVFP4_SUB; + + float vals_raw[QK_NVFP4_SUB]; + float amax_raw = 0.0f; + const int64_t base_idx = i3 * s03 + i2 * s02 + i01 * s01; +#pragma unroll + for (int k = 0; k < QK_NVFP4_SUB; k++) { + const int64_t i00 = i0_base + k; + if (i00 < ne00) { + const float v = x[base_idx + i00]; + vals_raw[k] = v; + amax_raw = fmaxf(amax_raw, fabsf(v)); + } else { + vals_raw[k] = 0.0f; + } + } + + static constexpr int test_offsets[5] = { 0, -1, 1, -2, 2}; + const int first_fp8_code = (int) ggml_cuda_fp32_to_ue4m3(amax_raw / 6.0f); + + float best_err = FLT_MAX; + uint8_t fp8_code = 0; + float subblock_scale = 0.0f; + +#pragma unroll // Check +/- 2 to find best code to reduce NVFP4 activation loss. Negligible overhead on Blackwell. + for (int i = 0; i < 5; i++) { + const int test_code = first_fp8_code + test_offsets[i]; + if (test_code < 0 || test_code > 0x7e) { + continue; + } + const uint8_t code = (uint8_t) test_code; + const float test_scale = ggml_cuda_ue4m3_to_fp32(code); + const float test_inv_scale = test_scale > 0.0f ? 0.5f / test_scale : 0.0f; + float cur_err = 0.0f; +#pragma unroll + for (int k = 0; k < QK_NVFP4_SUB; ++k) { + const float v = vals_raw[k]; + const uint8_t q = ggml_cuda_float_to_fp4_e2m1(v, test_inv_scale); + const float err_diff = fabsf(v) - fabsf(kvalues_mxfp4[q & 0x7]) * test_scale; + cur_err = fmaf(err_diff, err_diff, cur_err); + } + + if (cur_err < best_err) { + best_err = cur_err; + fp8_code = test_code; + subblock_scale = test_scale; + } + } + + const float inv_scale = subblock_scale > 0.0f ? 0.5f / subblock_scale : 0.0f; + uint32_t q0 = 0; + uint32_t q1 = 0; +#pragma unroll // this is faster than the previous __nv_fp4x4_e2m1 + for (int k = 0; k < QK_NVFP4_SUB / 4; ++k) { + q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 0], inv_scale) << (8 * k); + q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 8], inv_scale) << (8 * k + 4); + q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 4], inv_scale) << (8 * k); + q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 12], inv_scale) << (8 * k + 4); + } + + uint32_t * yqs = reinterpret_cast(yb->qs); + yqs[2 * sub + 0] = q0; + yqs[2 * sub + 1] = q1; + reinterpret_cast(yb->d4)[sub] = fp8_code; +#else + NO_DEVICE_CODE; // This is for Blackwell NVFP4 activations only. +#endif // defined(BLACKWELL_MMA_AVAILABLE) + +} + // quantize values in the format mxfp4 is stored which is interleaved nibbles // i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31 static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, @@ -316,28 +412,32 @@ void quantize_mmq_q8_1_cuda( } } -void quantize_mmq_mxfp4_cuda(const float * x, - const int32_t * ids, - void * vy, - [[maybe_unused]] const ggml_type type_src0, - const int64_t ne00, - const int64_t s01, - const int64_t s02, - const int64_t s03, - const int64_t ne0, - const int64_t ne1, - const int64_t ne2, - const int64_t ne3, - cudaStream_t stream) { - GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); - - constexpr int nwarps = 8; - constexpr int vals_per_warp = 2 * QK_MXFP4; - constexpr int vals_per_block = nwarps * vals_per_warp; - - const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block; - const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); - const dim3 block_size(WARP_SIZE, nwarps, 1); - - quantize_mmq_mxfp4<<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); +void quantize_mmq_fp4_cuda( + const float * x, const int32_t * ids, void * vy, const ggml_type type_src0, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + GGML_ASSERT(type_src0 == GGML_TYPE_MXFP4 || type_src0 == GGML_TYPE_NVFP4); + GGML_ASSERT(ne0 > 0); + + if (type_src0 == GGML_TYPE_NVFP4) { + GGML_ASSERT(ne00 % QK_NVFP4 == 0); + constexpr int nvfp4_block_size = 128; + const int64_t block_num_y = (ne0 + QK_NVFP4_SUB * nvfp4_block_size - 1) / (QK_NVFP4_SUB * nvfp4_block_size); + const dim3 block_size(nvfp4_block_size, 1, 1); + const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); + quantize_mmq_nvfp4<<>>( + x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + } else { + GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); + + constexpr int nwarps = 8; + constexpr int vals_per_warp = 2 * QK_MXFP4; + constexpr int vals_per_block = nwarps * vals_per_warp; + + const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block; + const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); + const dim3 block_size(WARP_SIZE, nwarps, 1); + + quantize_mmq_mxfp4<<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + } } diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 6a91df63578..768a3ae6de6 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -26,7 +26,7 @@ void quantize_mmq_q8_1_cuda( ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); -void quantize_mmq_mxfp4_cuda(const float * x, +void quantize_mmq_fp4_cuda(const float * x, const int32_t * ids, void * vy, ggml_type type_src0, From 53011393746fcdc9423af536fba0be02a1d66363 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 29 Apr 2026 08:55:07 +0200 Subject: [PATCH 216/249] TP: fix delayed AllReduce + zero-sized slices (llama/22489) --- ggml/src/ggml-backend-meta.cpp | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 41a61775bd6..fbc02d6458a 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1826,7 +1826,24 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, continue; } - i = get_i_delayed(i); + const int i_delayed = get_i_delayed(i); + + // If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices. + // A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has + // its compute flag disabled and thus gets its data zeroed out. + // If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled. + if (i_delayed > i) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + for (int ii = i + 1; ii <= i_delayed; ii++) { + bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; + } + } + } + } + + i = i_delayed; for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; From 3076725eb074338c0b0fa0bb50bfc00d4aec6497 Mon Sep 17 00:00:00 2001 From: hrushitfujitsu Date: Wed, 29 Apr 2026 13:27:37 +0530 Subject: [PATCH 217/249] ggml : add sve tuned code for gemm_q8_0_4x8_q8_0() kernel (llama/21916) * Added sve tuned code for gemm_q8_0_4x8_q8_0() kernel * Change arrays to static const in repack.cpp --------- Co-authored-by: Vithulep --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 65 +++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 80ff5ce549b..a7534443091 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -5023,6 +5023,71 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntb() * 8 == 256) { + const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx; + + static const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7}; + svuint32_t idx = svld1(svptrue_b32(), idx_arr); + static const uint32_t idx_arr1[8] = {0, 1, 2, 3, 1, 2, 3, 0}; + svuint32_t idx_sc1 = svld1(svptrue_b32(), idx_arr1); + static const uint32_t idx_arr2[8] = {0, 1, 2, 3, 0, 1, 2, 3}; + svuint32_t idx_sc2 = svld1(svptrue_b32(), idx_arr2); + + for (int y = 0; y < nr; y += 4) { + const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb; + + for (int x = 0; x < nc; x += ncols_interleaved) { + const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb; + const block_q8_0x4 * a_ptr = a_ptr_base; + + svfloat32_t acc_f32_01 = svdup_f32(0); + svfloat32_t acc_f32_23 = svdup_f32(0); + + for (int b = 0; b < nb; b++) { + + svint32_t acc_01 = svdup_s32(0); + svint32_t acc_23 = svdup_s32(0); + + // Process 4 chunks of 8 positions each + for (int chunk = 0; chunk < 4; chunk++) { + svint8_t s_a01 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32); + svint8_t s_a23 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32 + 16); + svint8_t s_b0123 = svld1_s8(svptrue_b8(), b_ptr->qs + chunk * 32); + + acc_01 = svmmla_s32(acc_01, s_a01, s_b0123); + acc_23 = svmmla_s32(acc_23, s_a23, s_b0123); + } + + // Reorder outputs from 2×2 tiles to row-major + // acc[01] = [r0c0, r0c1, r1c0, r1c1, r0c2, r0c3, r1c2, r1c3] + // acc[23] = [r2c0, r2c1, r3c0, r3c1, r2c2, r2c3, r3c2, r3c3] + + svint32_t row01 = svtbl_s32(acc_01, idx); + svint32_t row23 = svtbl_s32(acc_23, idx); + + svfloat16_t temp1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) a_ptr->d); + svfloat16_t temp2 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) b_ptr->d); + svfloat32_t sv_a_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp1, temp1)), idx_sc1); + svfloat32_t sv_b_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp2, temp2)), idx_sc2); + + acc_f32_01 = svmla_f32_x(svptrue_b32(), acc_f32_01, svcvt_f32_s32_x(svptrue_b32(), row01), svmul_lane_f32(sv_b_d, sv_a_d, 0)); + acc_f32_23 = svmla_f32_x(svptrue_b32(), acc_f32_23, svcvt_f32_s32_x(svptrue_b32(), row23), svmul_lane_f32(sv_b_d, sv_a_d, 2)); + a_ptr++; + b_ptr++; + } + + svbool_t pg4 = svptrue_pat_b32(SV_VL4); + svst1_f32(pg4, s + (y+0) * bs + x, acc_f32_01); + svst1_f32(pg4, s + (y+1) * bs + x, svext_f32(acc_f32_01, acc_f32_01, 4)); + svst1_f32(pg4, s + (y+2) * bs + x, acc_f32_23); + svst1_f32(pg4, s + (y+3) * bs + x, svext_f32(acc_f32_23, acc_f32_23, 4)); + } + } + return; + } +#endif // SVE compile-time end + #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx; From fa20229eeb54ee219fe9f67782bbae799d953f2b Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 29 Apr 2026 00:59:00 -0700 Subject: [PATCH 218/249] ggml-webgpu: Fix bug in FlashAttention support check (llama/22492) * Fix flashattention support check for devices that don't support subgroups * set path to none if kv_tile doesn't fit --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 44 ++++++++++++------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 ++ 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 34cbf3694b1..b7771ac230e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -494,9 +494,10 @@ struct ggml_webgpu_unary_pipeline_key_hash { /** FlashAttention */ enum ggml_webgpu_flash_attn_path : uint32_t { - GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 0u, - GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 1u, - GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 2u, + GGML_WEBGPU_FLASH_ATTN_PATH_NONE = 0u, + GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 1u, + GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 2u, + GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 3u, }; struct ggml_webgpu_flash_attn_pipeline_key { @@ -534,7 +535,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { }; struct ggml_webgpu_flash_attn_decisions { - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; + uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; uint32_t q_tile = 0; uint32_t kv_tile = 0; uint32_t wg_size = 0; @@ -709,19 +710,29 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec; - decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : - use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; + decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : + use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : + context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : + GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + return decisions; + } const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); decisions.kv_direct = key.kv_direct; + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + // invalidate if even the smallest kv_tile doesn't fit in shared memory + if (max_kv_tile == 0) { + decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + return decisions; + } if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); - decisions.q_tile = 1u; - decisions.kv_tile = std::max(8u, std::min(32u, min_kv_tile)); - decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + decisions.q_tile = 1u; + decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); + decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; + decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); if (decisions.kv_direct) { decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { @@ -734,9 +745,8 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.q_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m; decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(64u, ggml_webgpu_flash_attn_max_kv_tile(context, key)) : - std::min(ggml_webgpu_flash_attn_max_kv_tile(context, key), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + std::min(64u, max_kv_tile) : + std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE : std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); @@ -755,7 +765,6 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( context.sg_mat_n; } } - return decisions; } @@ -1364,7 +1373,7 @@ class ggml_webgpu_shader_lib { if (key.src_type == GGML_TYPE_Q1_0) { defines.push_back("BLOCK_SIZE=128u"); } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || - key.src_type == GGML_TYPE_IQ4_NL) { + key.src_type == GGML_TYPE_IQ4_NL) { defines.push_back("BLOCK_SIZE=32u"); } else if (key.src_type >= GGML_TYPE_Q2_K) { defines.push_back("BLOCK_SIZE=256u"); @@ -2325,6 +2334,7 @@ class ggml_webgpu_shader_lib { size_t storage_offset_alignment) { const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); + GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 762d9f8d1b4..f7fd73ae144 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3918,6 +3918,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; const bool has_mask = op->src[3] != nullptr; + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + supports_op = false; + break; + } if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], From 6119537e9aa65c4dc117c395c2d5acac07eb6b21 Mon Sep 17 00:00:00 2001 From: qiurui144 <39214303+qiurui144@users.noreply.github.com> Date: Wed, 29 Apr 2026 15:59:21 +0800 Subject: [PATCH 219/249] ggml-cpu: cmake: append xsmtvdotii march for SpacemiT IME (llama/22317) * ggml-cpu: cmake: append xsmtvdotii march for SpacemiT IME When GGML_CPU_RISCV64_SPACEMIT=ON is set, ime1_kernels.cpp contains inline asm for the vmadot family which requires the xsmtvdotii custom extension.(problem can see in some blogs and make sure in K3 platform) The current CMakeLists does not include xsmtvdotii, so any toolchain that honours the explicit -march (tested with SpacemiT GCC 15.2) fails at the assembler stage: Error: unrecognized opcode `vmadot v16,v14,v0', extension `xsmtvdotii' required Append _xsmtvdotii to MARCH_STR when GGML_CPU_RISCV64_SPACEMIT is enabled so the IME path can actually build with a capable toolchain. No effect on builds that leave GGML_CPU_RISCV64_SPACEMIT off. toolchain from https://www.spacemit.com/community/resources-download/Tools * Update ggml/src/ggml-cpu/CMakeLists.txt Co-authored-by: alex-spacemit --------- Co-authored-by: alex-spacemit --- ggml/src/ggml-cpu/CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index beebc4760d2..c1c225f0197 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -485,6 +485,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZIHINTPAUSE) string(APPEND MARCH_STR "_zihintpause") endif() + if (GGML_CPU_RISCV64_SPACEMIT) + # `xsmtvdotii' is only required for GCC >= 15. + if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND + CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 15) + string(APPEND MARCH_STR "_xsmtvdotii") + endif() + endif() list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) else() From 44e7803661cf16d648b8c0a5b250aea1167d99c1 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 29 Apr 2026 16:19:33 +0800 Subject: [PATCH 220/249] ggml-cuda: refactor fusion code (llama/22468) * ggml-cuda: refactor fusion code * apply formatting + make env variable truthy --- ggml/src/ggml-cuda/ggml-cuda.cu | 703 ++++++++++++++++---------------- 1 file changed, 355 insertions(+), 348 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 1c2c3b4ac69..fd8dd91714c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3640,6 +3640,357 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return false; } +// try and fuse nodes and return the number of nodes to skip +static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, int i) { + + static bool disable_fusion = getenv("GGML_CUDA_DISABLE_FUSION") != nullptr && std::atoi(getenv("GGML_CUDA_DISABLE_FUSION")); + if (disable_fusion) { + return 0; + } + + ggml_tensor * node = cgraph->nodes[i]; + + //topk-moe + if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX || + cgraph->nodes[i]->op == GGML_OP_ARGSORT) { + ggml_cuda_topk_moe_args args; + const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args); + std::vector ops; + + if (can_fuse) { + const ggml_tensor * logits = node->src[0]; + ggml_tensor * weights = nullptr; + ggml_tensor * ids = nullptr; + const ggml_tensor * bias = nullptr; + const ggml_tensor * clamp = nullptr; + const ggml_tensor * scale = nullptr; + + if (!args.delayed_softmax) { + ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX; + int out_nodes[2]; // nodes which can't be elided + + if (args.prob_bias) { + bias = cgraph->nodes[i + 2]->src[1]; + ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, GGML_OP_VIEW, + GGML_OP_GET_ROWS }); + out_nodes[0] = i + 4; + ids = cgraph->nodes[i + 4]; + } else { + ops.insert(ops.end(), + { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS }); + out_nodes[0] = i + 3; + ids = cgraph->nodes[i + 3]; + } + + if (args.norm) { + ops.insert(ops.end(), + { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_RESHAPE }); + clamp = cgraph->nodes[i + ops.size() - 3]; + } + if (args.scale) { + ops.insert(ops.end(), { GGML_OP_SCALE }); + scale = cgraph->nodes[i + ops.size() - 1]; + } + + weights = cgraph->nodes[i + ops.size() - 1]; + out_nodes[1] = i + ops.size() - 1; + + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(node, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + return ops.size() - 1; + } + } else if (!args.norm && !args.prob_bias) { + //special case gpt-oss, no norm, no bias. + ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }); + weights = cgraph->nodes[i + 5]; + ids = cgraph->nodes[i + 1]; + const ggml_tensor * softmax = cgraph->nodes[i + 4]; + + int out_nodes[2] = { i + 1, i + 5 }; + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + return ops.size() - 1; + } + } + } + } + + //RoPE + view + set-rows + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) { + ggml_tensor * rope = cgraph->nodes[i]; + ggml_tensor * set_rows = cgraph->nodes[i + 2]; + + ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows); + return 2; + } + + // multi-(add or mul) + if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) { + int n_fuse = 0; + ggml_op ops[8]; + std::fill(ops, ops + 8, node->op); + + for (; n_fuse <= 6; ++n_fuse) { + if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { + break; + } + if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { + break; + } + if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { + break; + } + } + + n_fuse++; + + if (n_fuse > 1) { + ggml_tensor fused_node; + memcpy(&fused_node, node, sizeof(ggml_tensor)); + for (int j = 0; j < n_fuse - 1; ++j) { + fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + } + fused_node.data = cgraph->nodes[i + n_fuse - 1]->data; + if (node->op == GGML_OP_ADD) { + ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse); + } else { + ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse); + } + return n_fuse - 1; + } + } + + bool fused_mul_mat_vec = false; + int fused_node_count = 0; + + // gate + glu + up + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 4]; + ggml_tensor * gate_bias_n = glu->src[0]; + ggml_tensor * up_bias_n = glu->src[1]; + + //we don't assume the order for {gate, up}. Instead infer it from the bias tensor + ggml_tensor * gate_n = nullptr; + ggml_tensor * up_n = nullptr; + + if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) { + gate_n = cgraph->nodes[i]; + up_n = cgraph->nodes[i + 2]; + } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) { + gate_n = cgraph->nodes[i + 2]; + up_n = cgraph->nodes[i]; + } else { + continue; + } + + auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) { + if (op_bias == GGML_OP_ADD) { + if (bias_node->src[0] == mul_node) { + return bias_node->src[1]; + } + if (bias_node->src[1] == mul_node) { + return bias_node->src[0]; + } + return (ggml_tensor *) nullptr; + } + GGML_ASSERT(op_bias == GGML_OP_ADD_ID); + GGML_ASSERT(bias_node->src[0] == mul_node); + return bias_node->src[1]; + }; + + ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op); + ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op); + + if (!up_bias_tensor || !gate_bias_tensor) { + continue; + } + + // we don't support repeating adds + if (bias_op == GGML_OP_ADD && (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) || + !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) { + continue; + } + + const ggml_tensor * src0 = up_n->src[0]; + const ggml_tensor * src1 = up_n->src[1]; + const ggml_tensor * ids = up_n->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 2]; + ggml_tensor * gate = glu->src[0]; + ggml_tensor * up = glu->src[1]; + + bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) || + (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]); + + if (!ok) { + continue; + } + + const ggml_tensor * src0 = up->src[0]; + const ggml_tensor * src1 = up->src[1]; + const ggml_tensor * ids = up->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + } + } + + if (fused_mul_mat_vec) { + return fused_node_count - 1; + } + + fused_mul_mat_vec = false; + fused_node_count = 0; + + // gate + add + glu + up + add + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (!ggml_can_fuse(cgraph, i, { op, bias_op })) { + continue; + } + + ggml_tensor * mm_node = cgraph->nodes[i]; + ggml_tensor * bias_node = cgraph->nodes[i + 1]; + + ggml_tensor * bias_tensor = nullptr; + if (bias_op == GGML_OP_ADD) { + if (bias_node->src[0] == mm_node) { + bias_tensor = bias_node->src[1]; + } else if (bias_node->src[1] == mm_node) { + bias_tensor = bias_node->src[0]; + } else { + continue; + } + } else { + if (bias_node->src[0] != mm_node) { + continue; + } + bias_tensor = bias_node->src[1]; + } + + const ggml_tensor * src0 = mm_node->src[0]; + const ggml_tensor * src1 = mm_node->src[1]; + const ggml_tensor * ids = mm_node->src[2]; + + if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) { + continue; + } + + if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) { + continue; + } + + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.x_bias = bias_tensor; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + } + + if (fused_mul_mat_vec) { + return fused_node_count - 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD }, {})) { + ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]); + return 2; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) { + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { + ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) || + ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) || + ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) { + ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) { + ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { + ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i + 2], node); + return 2; + } + + return 0; +} + static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { bool graph_evaluated_or_captured = false; @@ -3786,355 +4137,11 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } - // start of fusion operations - static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); - if (!disable_fusion) { - ggml_cuda_topk_moe_args args; - - if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX || - cgraph->nodes[i]->op == GGML_OP_ARGSORT) { - const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args); - - std::vector ops; - - if (can_fuse) { - const ggml_tensor * logits = node->src[0]; - ggml_tensor * weights = nullptr; - ggml_tensor * ids = nullptr; - const ggml_tensor * bias = nullptr; - const ggml_tensor * clamp = nullptr; - const ggml_tensor * scale = nullptr; - - if (!args.delayed_softmax) { - ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX; - int out_nodes[2]; // nodes which can't be elided - - if (args.prob_bias) { - bias = cgraph->nodes[i + 2]->src[1]; - ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS }); - out_nodes[0] = i + 4; - ids = cgraph->nodes[i + 4]; - } else { - ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, - GGML_OP_GET_ROWS }); - out_nodes[0] = i + 3; - ids = cgraph->nodes[i + 3]; - } - - if (args.norm) { - ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, - GGML_OP_DIV, GGML_OP_RESHAPE }); - clamp = cgraph->nodes[i + ops.size() - 3]; - } - if (args.scale) { - ops.insert(ops.end(), { GGML_OP_SCALE }); - scale = cgraph->nodes[i + ops.size() - 1]; - } - - weights = cgraph->nodes[i + ops.size() - 1]; - out_nodes[1] = i + ops.size() - 1; - - if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && - ggml_cuda_should_use_topk_moe(node, logits, weights, ids) && - ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) { - ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); - i += ops.size() - 1; - continue; - } - } else if (!args.norm && !args.prob_bias) { - //special case gpt-oss, no norm, no bias. - ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, - GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }); - weights = cgraph->nodes[i + 5]; - ids = cgraph->nodes[i + 1]; - const ggml_tensor * softmax = cgraph->nodes[i + 4]; - - int out_nodes[2] = { i + 1, i + 5 }; - if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && - ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) && - ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) { - ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); - i += ops.size() - 1; - continue; - } - } - } - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) { - ggml_tensor * rope = cgraph->nodes[i]; - ggml_tensor * set_rows = cgraph->nodes[i + 2]; - - ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows); - i += 2; - continue; - } - - if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) { - int n_fuse = 0; - ggml_op ops[8]; - std::fill(ops, ops + 8, node->op); - - for (; n_fuse <= 6; ++n_fuse){ - if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { - break; - } - if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { - break; - } - if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { - break; - } - } - - n_fuse++; + int nodes_to_skip = ggml_cuda_try_fuse(cuda_ctx, cgraph, i); - if (n_fuse > 1) { - ggml_tensor fused_node; - memcpy(&fused_node, node, sizeof(ggml_tensor)); - for (int j = 0; j < n_fuse - 1; ++j) { - fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; - } - fused_node.data = cgraph->nodes[i + n_fuse - 1]->data; - if (node->op == GGML_OP_ADD) { - ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse); - } else { - ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse); - } - i += n_fuse - 1; - - continue; - } - } - - bool fused_mul_mat_vec = false; - int fused_node_count = 0; - - for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { - const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; - - if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) { - ggml_tensor * glu = cgraph->nodes[i + 4]; - ggml_tensor * gate_bias_n = glu->src[0]; - ggml_tensor * up_bias_n = glu->src[1]; - - //we don't assume the order for {gate, up}. Instead infer it from the bias tensor - ggml_tensor * gate_n = nullptr; - ggml_tensor * up_n = nullptr; - - if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) { - gate_n = cgraph->nodes[i]; - up_n = cgraph->nodes[i + 2]; - } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) { - gate_n = cgraph->nodes[i + 2]; - up_n = cgraph->nodes[i]; - } else { - continue; - } - - auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) { - if (op_bias == GGML_OP_ADD) { - if (bias_node->src[0] == mul_node) { - return bias_node->src[1]; - } - if (bias_node->src[1] == mul_node) { - return bias_node->src[0]; - } - return (ggml_tensor *) nullptr; - } - GGML_ASSERT(op_bias == GGML_OP_ADD_ID); - GGML_ASSERT(bias_node->src[0] == mul_node); - return bias_node->src[1]; - }; - - ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op); - ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op); - - if (!up_bias_tensor || !gate_bias_tensor) { - continue; - } - - // we don't support repeating adds - if (bias_op == GGML_OP_ADD && - (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) || - !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) { - continue; - } - - const ggml_tensor * src0 = up_n->src[0]; - const ggml_tensor * src1 = up_n->src[1]; - const ggml_tensor * ids = up_n->src[2]; - - if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate_n->src[0]; - fusion_data.x_bias = up_bias_tensor; - fusion_data.gate_bias = gate_bias_tensor; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 5; - break; - } - - if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate_n->src[0]; - fusion_data.x_bias = up_bias_tensor; - fusion_data.gate_bias = gate_bias_tensor; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 5; - break; - } - } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) { - ggml_tensor * glu = cgraph->nodes[i + 2]; - ggml_tensor * gate = glu->src[0]; - ggml_tensor * up = glu->src[1]; - - bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) - || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]); - - if (!ok) continue; - - const ggml_tensor * src0 = up->src[0]; - const ggml_tensor * src1 = up->src[1]; - const ggml_tensor * ids = up->src[2]; - - if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate->src[0]; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 3; - break; - } - - if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate->src[0]; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 3; - break; - } - } - } - - if (fused_mul_mat_vec) { - i += fused_node_count - 1; - continue; - } - - fused_mul_mat_vec = false; - fused_node_count = 0; - - for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { - const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; - - if (!ggml_can_fuse(cgraph, i, { op, bias_op })) { - continue; - } - - ggml_tensor * mm_node = cgraph->nodes[i]; - ggml_tensor * bias_node = cgraph->nodes[i + 1]; - - ggml_tensor * bias_tensor = nullptr; - if (bias_op == GGML_OP_ADD) { - if (bias_node->src[0] == mm_node) { - bias_tensor = bias_node->src[1]; - } else if (bias_node->src[1] == mm_node) { - bias_tensor = bias_node->src[0]; - } else { - continue; - } - } else { - if (bias_node->src[0] != mm_node) { - continue; - } - bias_tensor = bias_node->src[1]; - } - - const ggml_tensor * src0 = mm_node->src[0]; - const ggml_tensor * src1 = mm_node->src[1]; - const ggml_tensor * ids = mm_node->src[2]; - - if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) { - continue; - } - - if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) { - continue; - } - - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.x_bias = bias_tensor; - - if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { - ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 2; - break; - } - - if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { - ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 2; - break; - } - } - - if (fused_mul_mat_vec) { - i += fused_node_count - 1; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { - ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); - i += 2; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) { - ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { - ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) || - ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) || - ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) { - ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) { - ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { - i += 2; - ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); - continue; - } + if (nodes_to_skip != 0) { + i += nodes_to_skip; + continue; } #ifndef NDEBUG assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); From ad670182d95023221b71a0852adf245e7b73cd1c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 29 Apr 2026 16:41:45 +0300 Subject: [PATCH 221/249] ggml : bump version to 0.10.1 (ggml/1469) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index b9f7deb150d..f7b6f1f334f 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 10) -set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 320c048724d0c6e393540ff6ac51eec23afea04c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 Apr 2026 21:44:28 +0300 Subject: [PATCH 222/249] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 58863dc6bbb..236ae95a80f 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -404fcb9d7c96989569e68c9e7881ee3465a05c50 +387fa29fbbf3149f06a631c7850b6c35c24b0232 From c59a7736051d497d9370db54f01c46845e6bb8ad Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 1 May 2026 11:53:27 +0300 Subject: [PATCH 223/249] examples : update to Q1_0 --- examples/common-ggml.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp index 6f02a2504c5..3f2eded86f7 100644 --- a/examples/common-ggml.cpp +++ b/examples/common-ggml.cpp @@ -74,6 +74,7 @@ bool ggml_common_quantize_0( case GGML_FTYPE_MOSTLY_BF16: case GGML_FTYPE_MOSTLY_MXFP4: case GGML_FTYPE_MOSTLY_NVFP4: + case GGML_FTYPE_MOSTLY_Q1_0: { fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); return false; @@ -215,6 +216,7 @@ bool ggml_common_quantize_0( case GGML_TYPE_TQ2_0: case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: + case GGML_TYPE_Q1_0: case GGML_TYPE_COUNT: { fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); From 9f2cec1840b38f510b0098fc767a622c40b8a433 Mon Sep 17 00:00:00 2001 From: shalinib-ibm Date: Wed, 29 Apr 2026 16:02:40 +0530 Subject: [PATCH 224/249] ggml-cpu : disable tiled matmul on AIX to fix page boundary segfault (llama/22293) * ggml-cpu : disable tiled matmul on AIX to fix page boundary segfault vec_xst operations in the tiled path crash on AIX when writing near 4KB page boundaries due to strict memory protection. Fall back to mnpack implementation on AIX for stable execution. Signed-off-by: Shalini Salomi Bodapati * Update ggml/src/ggml-cpu/llamafile/sgemm.cpp Co-authored-by: Aaron Teo * Update sgemm.cpp * Update sgemm.cpp --------- Signed-off-by: Shalini Salomi Bodapati Co-authored-by: Aaron Teo --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 34e320e2f50..e13828e3be6 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -2321,6 +2321,9 @@ class tinyBLAS_Q0_PPC { } void matmul(int64_t m, int64_t n) { + #if defined(_AIX) || defined(__BIG_ENDIAN__) + mnpack(0, m, 0, n); + #else const int64_t mc = 64; const int64_t kc = 64; int64_t nc = 64; @@ -2334,7 +2337,6 @@ class tinyBLAS_Q0_PPC { } else { n_aligned = (n / 64) * 64; } - if (n_aligned > 0) { if (n_aligned % 64 == 0) nc = 64; else if (n_aligned == n) nc = n; @@ -2352,6 +2354,7 @@ class tinyBLAS_Q0_PPC { } else { mnpack(0, m, 0, n); } + #endif } private: @@ -3191,12 +3194,16 @@ class tinyBLAS_PPC { } void matmul(int64_t m, int64_t n) { + #if defined(_AIX) || defined(__BIG_ENDIAN__) + mnpack(0, m, 0, n); + #else int64_t mc = 256; int64_t nc = 256; int64_t kc = 256; if (m % mc == 0 && n % nc == 0 && k % kc == 0) { matmul_tiled(m, n, mc, nc, kc); } else { mnpack(0, m, 0, n); } + #endif } private: From aec8e69c2f1f78ea3872361b1483ba99ebf74468 Mon Sep 17 00:00:00 2001 From: Anav Prasad Date: Wed, 29 Apr 2026 11:39:56 -0700 Subject: [PATCH 225/249] CUDA: fuse SSM_CONV + ADD(bias) + SILU (llama/22478) --- ggml/src/ggml-cuda/ggml-cuda.cu | 35 ++++++++++++++++++++++++++++++++- ggml/src/ggml-cuda/ssm-conv.cu | 34 ++++++++++++++++++++++++++------ ggml/src/ggml-cuda/ssm-conv.cuh | 2 +- 3 files changed, 63 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fd8dd91714c..0e6f74685d6 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3556,6 +3556,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; const ggml_tensor * silu = cgraph->nodes[node_idx+1]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { return false; @@ -3564,6 +3567,31 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return true; } + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_ADD + && ops.begin()[2] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { + const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; + const ggml_tensor * add = cgraph->nodes[node_idx+1]; + const ggml_tensor * silu = cgraph->nodes[node_idx+2]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } + + if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + + // ADD must consume ssm_conv's output and broadcast a 1-D channel-wise bias. + const ggml_tensor * bias = (add->src[0] == ssm_conv) ? add->src[1] : add->src[0]; + if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { + return false; + } + if (ggml_nelements(bias) != ssm_conv->ne[0] || bias->ne[0] != ssm_conv->ne[0]) { + return false; + } + + return true; + } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) { const ggml_tensor * unary = cgraph->nodes[node_idx]; @@ -3966,8 +3994,13 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph return 1; } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { + ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]); + return 2; + } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { - ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1]); + ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, cgraph->nodes[i + 1]); return 1; } diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index b77cdc1c137..4841389fbc8 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -3,6 +3,7 @@ template static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { @@ -27,6 +28,8 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float w[j] = w_block[tid * stride_w + j]; } + float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f; + for (int64_t i = 0; i < n_t; i++) { float sumf = 0.0f; @@ -42,12 +45,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float for (size_t j = 0; j < d_conv; j++) { sumf += x[(i + j) % d_conv] * w[j]; } + sumf += b; y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } template static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { @@ -97,6 +102,8 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, w[j] = w_block[tid * stride_w + j]; } + float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f; + // Compute from shared memory for (int64_t i = 0; i < local_n_t; i++) { float sumf = 0.0f; @@ -104,12 +111,13 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, for (size_t j = 0; j < d_conv; j++) { sumf += smem[tid * n_cols + i + j] * w[j]; } + sumf += b; y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } template -static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, +static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, const int64_t n_s, cudaStream_t stream) { @@ -120,14 +128,14 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + ssm_conv_f32<<>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float); ssm_conv_long_token_f32<<>>( - src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); + src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } }; @@ -140,11 +148,18 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int } } -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) { +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node, ggml_tensor * silu_dst) { const struct ggml_tensor * src0 = dst->src[0]; // conv_x const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + const bool fuse_bias = bias_add_node != nullptr; const bool fuse_silu = silu_dst != nullptr; + // bias always comes with silu. + GGML_ASSERT(!fuse_bias || fuse_silu); + + // The bias (when fused) is the non-conv operand of the ADD node. + const struct ggml_tensor * bias = fuse_bias ? (bias_add_node->src[0] == dst ? bias_add_node->src[1] : bias_add_node->src[0]) : nullptr; + // When fusing, write to silu_dst (the node downstream references). const struct ggml_tensor * out = fuse_silu ? silu_dst : dst; @@ -160,16 +175,23 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g const float * src0_d = (const float *) src0->data; const float * src1_d = (const float *) src1->data; + const float * bias_d = fuse_bias ? (const float *) bias->data : nullptr; float * dst_d = (float *) out->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(out->type == GGML_TYPE_F32); + if (fuse_bias) { + GGML_ASSERT(bias->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(bias)); + GGML_ASSERT(ggml_nelements(bias) == nr); + } + if (fuse_silu) { - ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + ssm_conv_f32_cuda(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], nc, nr, n_t, n_s, stream); } else { - ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + ssm_conv_f32_cuda(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], nc, nr, n_t, n_s, stream); } } diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh index f96a1cd2484..8514ca84920 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cuh +++ b/ggml/src/ggml-cuda/ssm-conv.cuh @@ -1,3 +1,3 @@ #include "common.cuh" -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr); +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node = nullptr, ggml_tensor * silu_dst = nullptr); From 66392cf1a2624fb2688d10c835bdc178f669460b Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 29 Apr 2026 11:51:21 -0700 Subject: [PATCH 226/249] hexagon: make vmem and buffer-size configurable (llama/22487) * hexagon: allow host to set max vmem size We use a sane default but it's helpful to allow for an override if needed. * hexagon: add support for measuring vmem space and move pinned mmaping management to host * hexagon: update vmem checks to use uint64 * hexagon: bump op buffers to 16 (matches max mmaps) * hexagon: bump default vmem to 3.2GB * hexagon: add support for autodetecting vmem space and some logging cleanup in that area * hexagon: fix whitespace warnings * Update scripts/snapdragon/adb/run-cli.sh Co-authored-by: Pascal * hex-adb: fix run-completion script --------- Co-authored-by: Pascal --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 238 ++++++++++++++---------- ggml/src/ggml-hexagon/htp/htp-ctx.h | 4 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 8 +- ggml/src/ggml-hexagon/htp/htp_iface.idl | 4 +- ggml/src/ggml-hexagon/htp/main.c | 27 ++- 5 files changed, 162 insertions(+), 119 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 0d9b5e289bb..9345da62168 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -48,14 +48,16 @@ using intvec = std::vector; using uintvec = std::vector; using u32vec = std::vector; -static size_t opt_ndev = 1; -static size_t opt_nhvx = 0; // use all -static int opt_arch = 0; // autodetect -static int opt_etm = 0; -static int opt_verbose = 0; -static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) -static int opt_hostbuf = 1; // hostbuf ON by default -static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +static int opt_arch = 0; // autodetect +static size_t opt_ndev = 1; +static size_t opt_nhvx = 0; // use all +static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +static size_t opt_vmem = HTP_OP_MAX_VMEM_DEFAULT; // max available va space for buffer mappings +static size_t opt_mbuf = 1ul * 1024 * 1024 * 1024; // max buffer size +static int opt_etm = 0; +static int opt_verbose = 0; +static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) +static int opt_hostbuf = 1; // hostbuf ON by default // Default PMU events, if profiling with PMU (mode=2) is enabled // See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html @@ -66,6 +68,7 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C } static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches + static std::regex* opt_opfilter = NULL; // regex of ops to not claim #define HEX_VERBOSE(...) \ @@ -110,7 +113,7 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct if (!opt_verbose) return; op_desc desc(op); - GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), + GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); } @@ -118,8 +121,6 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) { if (!opt_profile) return; - op_desc desc(op); - char pmu_str[256] = ""; if (opt_profile > 1) { static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters"); @@ -127,6 +128,7 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); } + op_desc desc(op); GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(), ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, op_usec, op_cycles, pmu_str); } @@ -191,33 +193,30 @@ struct ggml_hexagon_shared_buffer { bool mapped; bool pinned; - void mmap(bool pinned = false) { - int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD_DELAYED); + void mmap() { + fastrpc_map_flags flags = this->pinned ? FASTRPC_MAP_FD : FASTRPC_MAP_FD_DELAYED; + + int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, flags); if (err != 0) { GGML_LOG_ERROR("ggml-hex: %s buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(), sess->domain_id, this->size, this->fd, (unsigned) err); throw std::runtime_error("ggml-hex: fastrpc_mmap failed (see log for details)"); } - if (pinned) { - err = htp_iface_mmap(sess->handle, this->fd, this->size, pinned); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: %s buffer pinning failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(), - sess->domain_id, this->size, this->fd, (unsigned) err); - throw std::runtime_error("ggml-hex: htp_iface_mmap failed (see log for details)"); - } - } - - this->mapped = true; - this->pinned = pinned; HEX_VERBOSE("ggml-hex: %s mapped buffer: base %p size %zu fd %d pinned %u\n", sess->c_name(), (void *) this->base, this->size, this->fd, pinned); + + this->mapped = true; } void unmap() { if (!this->mapped) return; - htp_iface_munmap(sess->handle, this->fd); + if (!this->pinned) { + // HTP might still hold a reference, tell it drop it + htp_iface_munmap(sess->handle, this->fd); + } + fastrpc_munmap(sess->domain_id, this->fd, (void *) this->base, this->size); HEX_VERBOSE("ggml-hex: %s unmapped buffer: base %p size %zu fd %d\n", sess->c_name(), @@ -227,7 +226,7 @@ struct ggml_hexagon_shared_buffer { this->fd = -1; } - void alloc(size_t size, bool pinned = false) { + void alloc(size_t size) { if (this->base) return; this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, size); @@ -245,8 +244,7 @@ struct ggml_hexagon_shared_buffer { HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d pinned %d\n", sess->c_name(), (void *) this->base, this->size, this->fd, (int) pinned); - - mmap(pinned); + mmap(); } void free() { @@ -262,15 +260,14 @@ struct ggml_hexagon_shared_buffer { } ggml_hexagon_shared_buffer(ggml_hexagon_session * sess, size_t size, bool pinned = false) { - size += 4 * 1024; // extra page for padding - this->sess = sess; this->size = 0; this->base = nullptr; this->fd = -1; this->mapped = false; + this->pinned = pinned; - alloc(size, pinned); + alloc(size); } ~ggml_hexagon_shared_buffer() { @@ -1475,6 +1472,7 @@ static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer( ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast(buffer_type->context)->sess; try { + size += 4 * 1024; // guard page ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { @@ -1487,6 +1485,7 @@ static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffe ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast(buffer_type->context)->sess; try { + size += 4 * 1024; // guard page ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { @@ -1505,7 +1504,7 @@ static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffe } static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { - return 1UL * 1024 * 1024 * 1024; // 1GB per buffer + return opt_mbuf; // typically 1GB per buffer GGML_UNUSED(buffer_type); } @@ -1573,14 +1572,14 @@ struct ggml_hexagon_opbatch { d_map.clear(); } - ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size) { + ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size, size_t max_vmem) { this->sess = sess; n_bufs_max = HTP_OP_MAX_BUFS; n_ops_max = batch_size; n_tens_max = n_ops_max + n_ops_max * HTP_OP_MAX_INPUTS; - b_vmem_max = HTP_OP_MAX_VMEM; + b_vmem_max = max_vmem; ops.resize(n_ops_max); @@ -1592,6 +1591,9 @@ struct ggml_hexagon_opbatch { t_map.reserve(n_tens_max); d_map.reserve(n_tens_max); + GGML_LOG_INFO("ggml-hex: %s op batching: n-bufs %u n-tensors %u n-ops %u vmem %zu\n", + sess->c_name(), n_bufs_max, n_tens_max, n_ops_max, b_vmem_max); + reset(); } @@ -1925,6 +1927,8 @@ void ggml_hexagon_session::flush_batch() { // Bump pending flag (cleared in the session::flush once we get the response) this->op_pending++; // atomic inc + HEX_VERBOSE("ggml-hex: %s queue-opbatch: %p size %u\n", this->c_name(), dbuf.ptr, dbuf.size); + int err = dspqueue_write(this->queue, 0, 1, &dbuf, sizeof(req), (const uint8_t*) &req, DSPQUEUE_TIMEOUT); if (err != 0) { GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->c_name(), (unsigned) err); @@ -1944,6 +1948,35 @@ void ggml_hexagon_session::flush(bool all) { flush_pending(all); } +static size_t ggml_hexagon_measure_max_vmem(ggml_hexagon_session *sess) { + // Allocate a bunch pinned buffers till failure. + // This is kind of expensive but handy for figuring out exactly how much we can mmap on a specific device. + // Typically we're going to allocate all/most of these buffers anyway for the model weights. + + std::vector sbufs; + + const size_t MiB = 1024 * 1024; + const size_t GiB = MiB * 1024; + + size_t vmem = 0; + size_t step = 256u * MiB; + + try { + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + + while (1) { + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, step, true)); + vmem += step; + } + } catch (...) { } + + for (auto b : sbufs) { delete b; } + + return vmem - step; // backoff to account for overhead from internal mappings +} + void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_session = false; this->valid_handle = false; @@ -1957,7 +1990,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->op_pending = 0; - GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str()); + GGML_LOG_DEBUG("ggml-hex: %s allocating new session\n", this->name.c_str()); domain * my_domain = get_domain(this->domain_id); if (my_domain == NULL) { @@ -2033,9 +2066,6 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_handle = true; - GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(), - this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); - // Enable FastRPC QoS mode { struct remote_rpc_control_latency l; @@ -2047,6 +2077,9 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } } + GGML_LOG_INFO("ggml-hex: %s new session : session-id %d domain-id %d uri %s handle 0x%lx\n", this->c_name(), + this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); + const size_t req_q_size = (sizeof(htp_opbatch_req) * opt_opqueue * 2) + 1024; const size_t rsp_q_size = (sizeof(htp_opbatch_rsp) * opt_opqueue * 2) + 1024; @@ -2091,13 +2124,19 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } // Allocate buffers and state for op batching - this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch); this->op_queue = new ggml_hexagon_opqueue(this, opt_opbatch, opt_opqueue); - // Start processing op batch requests - err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx); + if (!opt_vmem) { + opt_vmem = ggml_hexagon_measure_max_vmem(this); + GGML_LOG_INFO("ggml-hex: %s measured max vmem %zu\n", this->c_name(), opt_vmem); + } + + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, opt_vmem); + + // Start dspqueue/opbatch processing + err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx, opt_vmem); if (err != 0) { - GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); + GGML_LOG_ERROR("ggml-hex: %s failed to start session: 0x%08x\n", this->c_name(), (unsigned) err); throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); } this->valid_iface = true; @@ -2108,17 +2147,17 @@ void ggml_hexagon_session::release() noexcept(true) { int err; - delete this->op_batch; - delete this->op_queue; - - // Stop the DSP-side service and close the queue if (this->valid_iface) { + // Stop dspqueue/opbatch processing err = htp_iface_stop(this->handle); if (err != 0) { GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err); } } + delete this->op_batch; + delete this->op_queue; + if (opt_etm) { err = htp_iface_etm(this->handle, 0); if (err != 0) { @@ -3380,21 +3419,6 @@ struct ggml_hexagon_registry { ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) { GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev); - if (!opt_arch) { - int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err); - opt_arch = 73; - } - } - -#if defined(__ANDROID__) - if (opt_arch < 75) { - opt_ndev = 1; - GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n"); - } -#endif - GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch); // Create devices / sessions @@ -3480,32 +3504,67 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL, "please update hexagon_type to match ggml_type"); - const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); - const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); - const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); - const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); - const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); - const char * str_opfilter= getenv("GGML_HEXAGON_OPFILTER"); - const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); - const char * str_etm = getenv("GGML_HEXAGON_ETM"); - const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); - const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); - const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); - const char * str_arch = getenv("GGML_HEXAGON_ARCH"); + const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); + const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); + const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); + const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); + const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); + const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER"); + const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); + const char * str_etm = getenv("GGML_HEXAGON_ETM"); + const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); + const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); + const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); + const char * str_arch = getenv("GGML_HEXAGON_ARCH"); + const char * str_vmem = getenv("GGML_HEXAGON_VMEM"); + const char * str_mbuf = getenv("GGML_HEXAGON_MBUF"); + + // Init Arch first since it affects other defaults + if (!str_arch) { + int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err); + opt_arch = 73; + } + } else { + if (str_arch[0] == 'v' || str_arch[0] == 'V') { + str_arch++; + } + opt_arch = strtoul(str_arch, NULL, 0); + } + + size_t MiB = 1024 * 1024; + + // Update vmem default + opt_vmem = opt_arch >= 75 ? HTP_OP_MAX_VMEM_DEFAULT : 3000 * MiB; auto RE_ICASE = std::regex_constants::icase; - opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; - opt_verbose = str_verbose ? atoi(str_verbose) : 0; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; - opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; - opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; - opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; - opt_etm = str_etm ? atoi(str_etm) : 0; - opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; - opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; - opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; + opt_verbose = str_verbose ? atoi(str_verbose) : 0; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; + opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; + opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; + opt_profile = str_profile ? atoi(str_profile) : 0; + opt_etm = str_etm ? atoi(str_etm) : 0; + opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; + opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; + opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf; + opt_vmem = str_vmem ? strtoul(str_vmem, NULL, 0) * MiB : opt_vmem; + + if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { + opt_ndev = GGML_HEXAGON_MAX_SESSIONS; + } + +#if defined(__ANDROID__) + if (opt_arch < 75) { + opt_ndev = 1; + GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n"); + } +#endif if (str_profile) { opt_pmu_evt = [&]() -> std::vector { @@ -3520,17 +3579,6 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { vec_to_str(opt_pmu_evt).c_str()); } - if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { - opt_ndev = GGML_HEXAGON_MAX_SESSIONS; - } - - if (str_arch) { - if (str_arch[0] == 'v') { - str_arch++; - } - opt_arch = strtoul(str_arch, NULL, 0); - } - reg->context = new ggml_hexagon_registry(reg); } diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index d704fedee9d..e9c563ca887 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -20,7 +20,7 @@ struct htp_mmap { uint64_t size; uint64_t base; uint32_t fd; - uint32_t pinned; + uint32_t reserved; }; // Scratchpad state @@ -77,6 +77,8 @@ struct htp_context { atomic_bool vtcm_valid; atomic_bool vtcm_needs_release; + uint64_t max_vmem; + struct htp_ops_context octx; #ifdef HTP_HAS_HMX diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 4397245c5b8..66a3150c1a0 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -90,15 +90,11 @@ enum htp_op_code { #define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS #define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS -#define HTP_OP_MAX_BUFS 8 +#define HTP_OP_MAX_BUFS 16 #define HTP_OP_MAX_REQS 256 #define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS) -#if __HVX_ARCH__ < 75 -#define HTP_OP_MAX_VMEM (3167538380u) -#else -#define HTP_OP_MAX_VMEM (3221225472u) -#endif +#define HTP_OP_MAX_VMEM_DEFAULT (3355443200u) #define HTP_MMAP_MAX_VMEM (2147483648u) diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index dbcafd1d856..d696a5fba0c 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -11,9 +11,9 @@ struct htp_iface_pmu_conf { }; interface htp_iface : remote_handle64 { - AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx); + AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem); AEEResult stop(); - AEEResult mmap(in uint32 fd, in uint32 size, in uint32 pinned); + AEEResult mmap(in uint32 fd, in uint32 size); AEEResult munmap(in uint32 fd); AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu); AEEResult etm(in uint32 enable); diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index f58347304be..49c1a15b344 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -210,7 +210,7 @@ AEEResult htp_iface_close(remote_handle64 handle) { return AEE_SUCCESS; } -AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 pinned) { +AEEResult htp_iface_mmap(remote_handle64 handle, uint32_t fd, uint32_t size) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { return AEE_EBADPARM; @@ -220,7 +220,6 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 for (uint32_t i=0; immap[i]; if (m->fd == fd) { - m->pinned = pinned; return AEE_SUCCESS; } } @@ -229,7 +228,7 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 for (uint32_t i=0; immap[i]; if (!m->size) { - FARF(HIGH, "mmap : fd %u size %u pinned %u", fd, size, pinned); + FARF(HIGH, "mmap : fd %u size %u", fd, size); #if __HVX_ARCH__ > 73 void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0); #else @@ -248,7 +247,6 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 m->base = (uint64_t) va; m->fd = fd; m->size = size; - m->pinned = pinned; return AEE_SUCCESS; } @@ -275,7 +273,6 @@ AEEResult htp_iface_munmap(remote_handle64 handle, uint32 fd) { m->size = 0; m->base = NULL; m->fd = -1; - m->pinned = 0; } } @@ -358,7 +355,7 @@ static void vtcm_free(struct htp_context * ctx) { static void htp_packet_callback(dspqueue_t queue, int error, void * context); static void htp_error_callback(dspqueue_t queue, int error, void * context); -AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) { +AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { @@ -376,12 +373,12 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que htp_error_callback, // Error callback; no errors expected on the DSP (void *) ctx, // Callback context &ctx->queue); - if (err) { FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err); return err; } + ctx->max_vmem = max_vmem; ctx->thread_id = qurt_thread_get_id(); ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id); @@ -622,8 +619,8 @@ static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct } static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) { - if (m->size && !m->pinned) { - FARF(HIGH, "unmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned); + if (m->size) { + FARF(HIGH, "unmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size); #if __HVX_ARCH__ > 73 HAP_munmap2((void *) m->base, m->size); #else @@ -660,9 +657,8 @@ static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) { m->base = b->base = (uint64_t) va; m->fd = b->fd; m->size = b->size; - m->pinned = 0; - FARF(HIGH, "mmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned); + FARF(HIGH, "mmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size); return; } } @@ -672,8 +668,8 @@ static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uin uint32_t m_reuse = 0; // mmap reuse mask (index from ctx->mmap array) uint32_t b_reuse = 0; // buf reuse count - size_t m_vmem = 0; // mapped vmem - size_t e_vmem = 0; // extra vmem + uint64_t m_vmem = 0; // mapped vmem + uint64_t e_vmem = 0; // extra vmem // See what we can reuse for (uint32_t i=0; i < n_bufs; i++) { @@ -687,9 +683,10 @@ static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uin // See how much vmem we have mmaped right now for (uint32_t i=0; immap[i].size; } - FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu n-bufs %u b-reuse %u", m_vmem, e_vmem, n_bufs, b_reuse); + FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu max-vmem %zu : n-bufs %u b-reuse %u", + (size_t) m_vmem, (size_t) e_vmem, (size_t) ctx->max_vmem, n_bufs, b_reuse); - if ((m_vmem + e_vmem) > HTP_OP_MAX_VMEM) { + if ((m_vmem + e_vmem) > ctx->max_vmem) { // Drop unused mappings for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) { bool used = m_reuse & (1< Date: Wed, 29 Apr 2026 22:58:32 -0700 Subject: [PATCH 227/249] add fast matmul iquants (llama/22504) --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 19 + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- .../wgsl-shaders/mul_mat_decls.tmpl | 423 ++++++++++++++++++ 3 files changed, 443 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index b7771ac230e..5239164cd00 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1806,6 +1806,25 @@ class ggml_webgpu_shader_lib { defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } + variant += std::string("_") + src0_name; break; } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f7fd73ae144..5e55a2a1e1b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1422,7 +1422,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: - use_fast = is_vec; + use_fast = true; break; default: break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 15b22c4f731..51cf08f196f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -740,3 +740,426 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } #endif // INIT_SRC0_SHMEM_Q6_K + +#ifdef INIT_SRC0_SHMEM_IQ4_NL +const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 18u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); + + let pos = k_in_block % 16u; + let nib_shift = (k_in_block / 16u) * 4u; + let q_packed = load_u32_at_src0(block_byte_base + 2u + (pos / 4u) * 4u); + let nib = (get_byte(q_packed, pos % 4u) >> nib_shift) & 0xFu; + + shmem[elem_idx] = d * f16(kvalues_iq4nl[nib]); + } +} +#endif // INIT_SRC0_SHMEM_IQ4_NL + +#ifdef INIT_SRC0_SHMEM_IQ4_XS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 136u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + + let d_scales_h = load_u32_at_src0(block_byte_base); + let d = bitcast>(d_scales_h).x; + let scales_h = d_scales_h >> 16u; + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + + let scales_l_word = load_u32_at_src0(block_byte_base + 4u); + let ls_lo = (get_byte(scales_l_word, ib / 2u) >> ((ib & 1u) * 4u)) & 0xFu; + let ls_hi = ((scales_h >> (2u * ib)) & 3u) << 4u; + let dl = d * f16(i32(ls_lo | ls_hi) - 32); + + let iqs = ib * 16u + (pos % 16u); + let nib_shift = (pos / 16u) * 4u; + let q_packed = load_u32_at_src0(block_byte_base + 8u + (iqs / 4u) * 4u); + let nib = (get_byte(q_packed, iqs % 4u) >> nib_shift) & 0xFu; + + shmem[elem_idx] = dl * f16(kvalues_iq4nl[nib]); + } +} +#endif // INIT_SRC0_SHMEM_IQ4_XS + +#ifdef INIT_SRC0_SHMEM_IQ1_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 50u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + let l = pos / 8u; + let j = pos % 8u; + + let qh = load_u32_at_src0(block_byte_base + 34u + ib * 2u) & 0xFFFFu; + let dl = d * (2.0 * f32((qh >> 12u) & 7u) + 1.0); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); + + let qs_w = load_u32_at_src0(block_byte_base + 2u + ib * 4u); + let ig = (get_byte(qs_w, l) | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + + let gw = iq1_grid[(ig + j) / 16u]; + let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u; + let gs = bitcast(g << 30u) >> 30u; + + shmem[elem_idx] = f16(dl * (f32(gs) + delta)); + } +} +#endif // INIT_SRC0_SHMEM_IQ1_S + +#ifdef INIT_SRC0_SHMEM_IQ1_M +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 56u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + + let scales0 = load_u32_at_src0(block_byte_base + 48u); + let scales1 = load_u32_at_src0(block_byte_base + 52u); + let scale_packed = ((scales0 >> 12u) & 0xFu) | + ((scales0 >> 24u) & 0x00F0u) | + ((scales1 >> 4u) & 0x0F00u) | + ((scales1 >> 16u) & 0xF000u); + let d = f32(bitcast>(scale_packed).x); + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + let l = pos / 8u; + let j = pos % 8u; + + let scales = select(scales0, scales1, ib >= 4u); + let sw = (scales >> (16u * ((ib / 2u) % 2u))) & 0xFFFFu; + let s_pair = (sw >> (6u * (ib % 2u) + 3u * (l / 2u))) & 0x7u; + let dl = d * f32(2u * s_pair + 1u); + + let qh_word = load_u32_at_src0(block_byte_base + 32u + (ib / 2u) * 4u); + let qh = qh_word >> (16u * (ib % 2u)); + let qh_nib = (qh >> (4u * l)) & 0xFu; + + let qs_w = load_u32_at_src0(block_byte_base + ib * 4u); + let idx = get_byte(qs_w, l) | ((qh_nib & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_nib & 0x8u) != 0u); + + let ig = idx * 8u; + let gw = iq1_grid[(ig + j) / 16u]; + let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u; + let gs = bitcast(g << 30u) >> 30u; + + shmem[elem_idx] = f16(dl * (f32(gs) + delta)); + } +} +#endif // INIT_SRC0_SHMEM_IQ1_M + +#ifdef INIT_SRC0_SHMEM_IQ2_XXS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 66u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let entry_idx = k_in_block / 8u; + let j = k_in_block % 8u; + + let ib = entry_idx & ~3u; + let l = entry_idx & 3u; + + let aux0 = load_u32_at_src0(block_byte_base + 2u + ib * 2u); + let aux1 = load_u32_at_src0(block_byte_base + 2u + (ib + 2u) * 2u); + let db = d * (0.5 + f32(aux1 >> 28u)) * 0.25; + + let ig = get_byte(aux0, l) * 8u; + let is = (aux1 >> (7u * l)) & 127u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let g = get_byte(iq2xxs_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(db * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_XXS + +#ifdef INIT_SRC0_SHMEM_IQ2_XS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 74u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let entry_idx = k_in_block / 8u; + let j = k_in_block % 8u; + + let ib = entry_idx & ~3u; + let l = entry_idx & 3u; + + let scales_word = load_u32_at_src0(block_byte_base + 66u + (ib / 16u) * 4u); + let s = get_byte(scales_word, (ib % 16u) / 4u); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u); + let dl = d * (0.5 + f32(s_nib)) * 0.25; + + let qs_word = load_u32_at_src0(block_byte_base + 2u + (ib + l) * 2u); + let qs_val = qs_word & 0xFFFFu; + let ig = (qs_val & 511u) * 8u; + let is = qs_val >> 9u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let g = get_byte(iq2xs_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_XS + +#ifdef INIT_SRC0_SHMEM_IQ2_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 82u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 32u; + let l = (k_in_block % 32u) / 8u; + let j = k_in_block % 8u; + + let scales_word = load_u32_at_src0(block_byte_base + 74u + (ib / 4u) * 4u); + let s = get_byte(scales_word, ib % 4u); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u); + let dl = d * (0.5 + f32(s_nib)) * 0.25; + + let qs_word = load_u32_at_src0(block_byte_base + 2u + ib * 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 4u) * 4u); + let qh_b = (get_byte(qh_word, ib % 4u) << (8u - 2u * l)) & 0x300u; + let ig = (get_byte(qs_word, l) | qh_b) * 8u; + + let signs_word = load_u32_at_src0(block_byte_base + 34u + ib * 4u); + let signs = get_byte(signs_word, l); + + let g = get_byte(iq2s_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_S + +#ifdef INIT_SRC0_SHMEM_IQ3_XXS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 98u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib_pair = k_in_block / 32u; + let in_pair = k_in_block % 32u; + let l = in_pair / 8u; + let in_l = in_pair % 8u; + let k2 = in_l / 4u; + let j = in_l % 4u; + + let ib = ib_pair * 2u; + let sc_sign_off = block_byte_base + 2u + (ib + 32u) * 2u; + let sc_sign = load_u32_at_src0(sc_sign_off); + let db = d * (0.5 + f32(sc_sign >> 28u)) * 0.5; + let is = (sc_sign >> (7u * l)) & 127u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 2u + l) * 2u) & 0xFFFFu; + let ig_byte = get_byte(ig_word, k2); + let g = get_byte(iq3xxs_grid[ig_byte], j); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); + + shmem[elem_idx] = f16(db * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ3_XXS + +#ifdef INIT_SRC0_SHMEM_IQ3_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 110u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 64u; + let rest = k_in_block % 64u; + let k = rest / 32u; + let in_k = rest % 32u; + let l = in_k / 8u; + let in_l = in_k % 8u; + let k2 = in_l / 4u; + let j = in_l % 4u; + + let scales_word = load_u32_at_src0(block_byte_base + 106u); + let s = get_byte(scales_word, ib); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, k != 0u); + let dl = d * (1.0 + 2.0 * f32(s_nib)); + + let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 2u) * 4u); + let qh_byte = get_byte(qh_word, (ib % 2u) * 2u + k); + + let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 8u + k * 4u + l) * 2u) & 0xFFFFu; + let ig_lo = get_byte(ig_word, 0u) | ((qh_byte << (8u - 2u * l)) & 256u); + let ig_hi = get_byte(ig_word, 1u) | ((qh_byte << (7u - 2u * l)) & 256u); + let ig = select(ig_lo, ig_hi, k2 != 0u); + + let signs_word = load_u32_at_src0(block_byte_base + 74u + (ib * 2u + k) * 4u); + let signs = get_byte(signs_word, l); + + let g = get_byte(iq3s_grid[ig], j); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ3_S From 582d2562a41f89388e5040253100780b3934c7c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 30 Apr 2026 13:04:50 +0200 Subject: [PATCH 228/249] CUDA: fix tile FA kernel on Pascal (llama/22541) --- ggml/src/ggml-cuda/fattn-tile.cuh | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 928b856f9d2..585f2c22853 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,7 +68,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) @@ -130,7 +130,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) @@ -1124,7 +1124,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr size_t nbytes_shared = 0; #ifdef GGML_USE_HIP - if constexpr (DV <= 128) { + if constexpr (DKQ <= 128) { if (Q->ne[1] > 32/ncols2) { constexpr int cols_per_block = 64; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; @@ -1138,7 +1138,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm #endif // GGML_USE_HIP #ifndef GGML_USE_HIP - if constexpr (DV <= 256) + if constexpr (DKQ <= 256) #endif // GGML_USE_HIP { if (Q->ne[1] > 16/ncols2) { @@ -1220,11 +1220,22 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; - if constexpr (DKQ == 320) { // Mistral Small 4 + if constexpr (DKQ == 320) { + // This branch is only used for Mistral Small 4 which has a GQA ratio of 32. + // On AMD, simply use that GQA ratio with 32 columns / block since we always have enough SRAM. + // On NVIDIA however, the tile kernel is only used for GPUs that can't use the mma kernel (Pascal and older). + // Therefore, use a GQA ratio of 16 with 16 columns / block to stay below 48 kiB of SRAM / block. +#ifdef GGML_USE_HIP if (use_gqa_opt && gqa_ratio % 32 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; } +#else + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } +#endif // GGML_USE_HIP GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32"); } From 0c7c3ba570cb0b6f03da762d53ba211022cfb89a Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 30 Apr 2026 17:37:13 +0200 Subject: [PATCH 229/249] vulkan: add get/set tensor 2d functions (llama/22514) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * vulkan: add get/set_tensor_2d functions * fix backend interface comments * Update ggml/src/ggml-metal/ggml-metal.cpp Co-authored-by: Sigbjørn Skjæret --- ggml/src/ggml-backend-meta.cpp | 2 +- ggml/src/ggml-blas/ggml-blas.cpp | 4 +- ggml/src/ggml-cann/ggml-cann.cpp | 2 +- ggml/src/ggml-cpu/ggml-cpu.cpp | 2 +- ggml/src/ggml-cuda/ggml-cuda.cu | 4 +- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 2 +- ggml/src/ggml-metal/ggml-metal.cpp | 6 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 4 +- ggml/src/ggml-rpc/ggml-rpc.cpp | 4 +- ggml/src/ggml-sycl/ggml-sycl.cpp | 2 +- ggml/src/ggml-virtgpu/ggml-backend.cpp | 2 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 213 +++++++++++++++++++------ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- ggml/src/ggml-zdnn/ggml-zdnn.cpp | 2 +- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 2 +- 15 files changed, 181 insertions(+), 72 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index fbc02d6458a..c0ffd9a048b 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -2100,8 +2100,8 @@ static const ggml_backend_i ggml_backend_meta_i = { /* .free = */ ggml_backend_meta_free, /* .set_tensor_async = */ ggml_backend_meta_set_tensor_async, /* .get_tensor_async = */ ggml_backend_meta_get_tensor_async, - /* .get_tensor_2d_async = */ nullptr, /* .set_tensor_2d_async = */ nullptr, + /* .get_tensor_2d_async = */ nullptr, /* .cpy_tensor_async = */ nullptr, /* .synchronize = */ ggml_backend_meta_synchronize, /* .graph_plan_create = */ nullptr, diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 05245b69807..b4c735267e0 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -262,9 +262,9 @@ static struct ggml_backend_i blas_backend_i = { /* .get_name = */ ggml_backend_blas_get_name, /* .free = */ ggml_backend_blas_free, /* .set_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, - /* .set_tensor_2d_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 3618ba7f6f6..5f51ea3bb3c 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2746,8 +2746,8 @@ static const ggml_backend_i ggml_backend_cann_interface = { /* .free = */ ggml_backend_cann_free, /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async, /* .synchronize = */ ggml_backend_cann_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 49f840be207..128883b41ce 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -195,8 +195,8 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .free = */ ggml_backend_cpu_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0e6f74685d6..fbe0fa06242 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4588,8 +4588,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .free = */ ggml_backend_cuda_free, /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, - /* .get_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async, - /* .set_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async, + /* .set_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async, + /* .get_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async, /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async, /* .synchronize = */ ggml_backend_cuda_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 9345da62168..17ac083f4ea 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -3036,8 +3036,8 @@ static struct ggml_backend_i hexagon_backend_i = { /* .free = */ ggml_backend_hexagon_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_hexagon_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 6a836e45908..cc329d67594 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -166,8 +166,8 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, - /* .get_tensor_2d_async = */ NULL, - /* .set_tensor_2d_async = */ NULL, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, /* .clear = */ ggml_backend_metal_buffer_private_clear, /* .reset = */ NULL, @@ -567,8 +567,8 @@ static ggml_backend_i ggml_backend_metal_i = { /* .free = */ ggml_backend_metal_free, /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups /* .synchronize = */ ggml_backend_metal_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4d31591a4a6..11f72a5198a 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4343,9 +4343,9 @@ static ggml_backend_i ggml_backend_opencl_i = { /* .free = */ ggml_backend_opencl_free, /* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */ /* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */ - /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ /* .synchronize = */ ggml_backend_opencl_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 505bec73d37..7176d2feef9 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -740,9 +740,9 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .free = */ ggml_backend_rpc_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_rpc_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 1eead625e76..f06147eeeb8 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4700,8 +4700,8 @@ static ggml_backend_i ggml_backend_sycl_interface = { /* .free = */ ggml_backend_sycl_free, /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async, /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_sycl_cpy_tensor_async, // // TODO: update for the new // interface diff --git a/ggml/src/ggml-virtgpu/ggml-backend.cpp b/ggml/src/ggml-virtgpu/ggml-backend.cpp index 2b978556228..12756c9282f 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend.cpp @@ -34,8 +34,8 @@ static ggml_backend_i ggml_backend_remoting_interface = { /* .free = */ ggml_backend_remoting_free, /* .set_tensor_async = */ NULL, // ggml_backend_remoting_set_tensor_async, /* .get_tensor_async = */ NULL, // ggml_backend_remoting_get_tensor_async, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_remoting_cpy_tensor_async, /* .synchronize = */ NULL, // ggml_backend_remoting_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 69c24bb5877..10b73317943 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -6845,7 +6845,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont } } -static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { +static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); // Check if src is pinned memory vk_buffer buf = nullptr; @@ -6855,7 +6855,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz if (buf != nullptr) { // Memory is pinned, use as staging buffer std::vector slices(1); - if (width == spitch) { + if (width == spitch && width == dpitch) { // Only do single write if stride is equal slices[0].srcOffset = buf_offset; slices[0].dstOffset = offset; @@ -6864,7 +6864,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz slices.resize(height); for (size_t i = 0; i < height; i++) { slices[i].srcOffset = buf_offset + i * spitch; - slices[i].dstOffset = offset + i * width; + slices[i].dstOffset = offset + i * dpitch; slices[i].size = width; } } @@ -6881,21 +6881,30 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz } // Staging buffer required - const size_t copy_size = width*height; - ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); + const size_t staging_size = width * height; + ggml_vk_ensure_sync_staging_buffer(dst->device, staging_size); vk_buffer& staging_buffer = dst->device->sync_staging; - VkBufferCopy buf_copy = { - 0, - offset, - copy_size}; + std::vector slices(1); + if (width == dpitch) { + slices[0].srcOffset = 0; + slices[0].dstOffset = offset; + slices[0].size = staging_size; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = i * width; + slices[i].dstOffset = offset + i * dpitch; + slices[i].size = width; + } + } ggml_vk_sync_buffers(nullptr, subctx); - vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + subctx->s->buffer->buf.copyBuffer((VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, slices); if (width == spitch) { - deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); + deferred_memcpy((uint8_t *)staging_buffer->ptr, src, staging_size, &subctx->in_memcpys); } else { for (size_t i = 0; i < height; i++) { deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); @@ -6906,24 +6915,24 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); - return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); + return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, size, 1, sync_staging); } -static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { +static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height) { VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")"); // Buffer is already mapped if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); for (size_t i = 0; i < height; i++) { - memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); + memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width); } } else { std::lock_guard guard(dst->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); - bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); + bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, dpitch, width, height, true); GGML_ASSERT(ret); ggml_vk_ctx_end(subctx); @@ -6944,7 +6953,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")"); - ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); + ggml_vk_buffer_write_2d(dst, offset, src, size, size, size, 1); } static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { @@ -6990,15 +6999,35 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size } // Fall back to staging buffer - const size_t copy_size = dpitch * height; - ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); + const size_t staging_size = width * height; + ggml_vk_ensure_sync_staging_buffer(src->device, staging_size); vk_buffer& staging_buffer = src->device->sync_staging; + std::vector staging_slices(1); + if (width == spitch) { + staging_slices[0].srcOffset = offset; + staging_slices[0].dstOffset = 0; + staging_slices[0].size = staging_size; + } else { + staging_slices.resize(height); + for (size_t i = 0; i < height; i++) { + staging_slices[i].srcOffset = offset + i * spitch; + staging_slices[i].dstOffset = i * width; + staging_slices[i].size = width; + } + } + ggml_vk_sync_buffers(nullptr, subctx); - subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, slices); + subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, staging_slices); - deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); + if (width == dpitch) { + deferred_memcpy(dst, staging_buffer->ptr, staging_size, &subctx->out_memcpys); + } else { + for (size_t i = 0; i < height; i++) { + deferred_memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) staging_buffer->ptr + i * width, width, &subctx->out_memcpys); + } + } return true; } @@ -7006,8 +7035,8 @@ static bool ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging); } -static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { - VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); +static void ggml_vk_buffer_read_2d(vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height) { + VK_LOG_DEBUG("ggml_vk_buffer_read_2d(" << src->buffer << ", " << offset << ", " << width << ", " << height << ")"); // If the device is not an UMA device the memory is host-accessible through rebar. While writing // through PCIe is sufficient fast reading back data from PCIe is slower than going through @@ -7015,18 +7044,20 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); - memcpy(dst, (uint8_t *) src->ptr + offset, size); + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width); + } } else { std::lock_guard guard(src->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(src->device, subctx); - bool ret = ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); + bool ret = ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, spitch, dpitch, width, height, true); GGML_ASSERT(ret); ggml_vk_ctx_end(subctx); ggml_vk_submit(subctx, src->device->fence); - VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read_2d waitForFences"); src->device->device.resetFences({ src->device->fence }); ggml_vk_queue_command_pools_cleanup(src->device); @@ -7036,6 +7067,11 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ } } +static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); + ggml_vk_buffer_read_2d(src, offset, dst, size, size, size, 1); +} + static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); // Make sure both buffers are on same device @@ -7067,7 +7103,7 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr // Copy to src staging buffer ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); // Copy to dst buffer - ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1); + ggml_vk_buffer_write(dst, dst_offset, src->device->sync_staging->ptr, size); } } @@ -13615,6 +13651,20 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } +static void ggml_backend_vk_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " << + n_copies << ", " << stride_tensor << ", " << stride_data << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + if (size == 0) { + return; + } + + ggml_vk_buffer_write_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_data, stride_tensor, size, n_copies); +} + static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; @@ -13628,6 +13678,21 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } +static void ggml_backend_vk_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " << + n_copies << ", " << stride_tensor << ", " << stride_data << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + if (size == 0) { + return; + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_tensor, stride_data, size, n_copies); +} + static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { if (ggml_nbytes(src) == 0) { return true; @@ -13662,8 +13727,8 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, - /* .set_tensor_2d = */ NULL, - /* .get_tensor_2d = */ NULL, + /* .set_tensor_2d = */ ggml_backend_vk_buffer_set_tensor_2d, + /* .get_tensor_2d = */ ggml_backend_vk_buffer_get_tensor_2d, /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, /* .clear = */ ggml_backend_vk_buffer_clear, /* .reset = */ NULL, @@ -13819,8 +13884,9 @@ static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_b return &ctx->device->buffer_type; } -static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); +static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_set_tensor_2d_async(" << size << ", " << n_copies << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); @@ -13834,7 +13900,6 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor if (ctx->device->async_use_transfer_queue) { if (ctx->transfer_ctx.expired()) { - // Initialize new transfer context cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); ctx->transfer_ctx = cpy_ctx; ggml_vk_ctx_begin(ctx->device, cpy_ctx); @@ -13849,25 +13914,48 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size); + bool ret = ggml_vk_buffer_write_2d_async(cpy_ctx, buf, dst_offset, data, stride_data, stride_tensor, size, n_copies); if (!ret) { - ggml_vk_ensure_sync_staging_buffer(ctx, size); + const size_t staging_size = size * n_copies; + ggml_vk_ensure_sync_staging_buffer(ctx, staging_size); ggml_vk_sync_buffers(nullptr, cpy_ctx); - vk::BufferCopy buffer_cpy; - buffer_cpy.srcOffset = 0; - buffer_cpy.dstOffset = dst_offset; - buffer_cpy.size = size; + std::vector slices(1); + if (size == stride_tensor) { + slices[0].srcOffset = 0; + slices[0].dstOffset = dst_offset; + slices[0].size = staging_size; + } else { + slices.resize(n_copies); + for (size_t i = 0; i < n_copies; i++) { + slices[i].srcOffset = i * size; + slices[i].dstOffset = dst_offset + i * stride_tensor; + slices[i].size = size; + } + } - cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); - deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys); + cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, slices); + + if (size == stride_data) { + deferred_memcpy(ctx->sync_staging->ptr, data, staging_size, &cpy_ctx->in_memcpys); + } else { + for (size_t i = 0; i < n_copies; i++) { + deferred_memcpy((uint8_t *)ctx->sync_staging->ptr + i * size, (const uint8_t *)data + i * stride_data, size, &cpy_ctx->in_memcpys); + } + } ggml_vk_synchronize(ctx); } } -static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); +static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); + ggml_backend_vk_set_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size); +} + +static void ggml_backend_vk_get_tensor_2d_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_get_tensor_2d_async(" << size << ", " << n_copies << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); @@ -13882,24 +13970,45 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ vk_buffer buf = buf_ctx->dev_buffer; auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size); + bool ret = ggml_vk_buffer_read_2d_async(compute_ctx, buf, src_offset, data, stride_tensor, stride_data, size, n_copies); - // If that failed, copy synchronously through a staging buffer if (!ret) { - ggml_vk_ensure_sync_staging_buffer(ctx, size); + const size_t staging_size = size * n_copies; + ggml_vk_ensure_sync_staging_buffer(ctx, staging_size); ggml_vk_sync_buffers(nullptr, compute_ctx); - vk::BufferCopy buffer_cpy; - buffer_cpy.srcOffset = src_offset; - buffer_cpy.dstOffset = 0; - buffer_cpy.size = size; + std::vector slices(1); + if (size == stride_tensor) { + slices[0].srcOffset = src_offset; + slices[0].dstOffset = 0; + slices[0].size = staging_size; + } else { + slices.resize(n_copies); + for (size_t i = 0; i < n_copies; i++) { + slices[i].srcOffset = src_offset + i * stride_tensor; + slices[i].dstOffset = i * size; + slices[i].size = size; + } + } + + compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, slices); - compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); - deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys); + if (size == stride_data) { + deferred_memcpy(data, ctx->sync_staging->ptr, staging_size, &compute_ctx->out_memcpys); + } else { + for (size_t i = 0; i < n_copies; i++) { + deferred_memcpy((uint8_t *)data + i * stride_data, (const uint8_t *)ctx->sync_staging->ptr + i * size, size, &compute_ctx->out_memcpys); + } + } ggml_vk_synchronize(ctx); } } +static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); + ggml_backend_vk_get_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size); +} + static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context; @@ -15123,8 +15232,8 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .free = */ ggml_backend_vk_free, /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async, /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async, - /* .get_tensor_2d_async = */ NULL, - /* .set_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ ggml_backend_vk_set_tensor_2d_async, + /* .get_tensor_2d_async = */ ggml_backend_vk_get_tensor_2d_async, /* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async, /* .synchronize = */ ggml_backend_vk_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5e55a2a1e1b..a1dccfc0f5a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3107,8 +3107,8 @@ static ggml_backend_i ggml_backend_webgpu_i = { /* .free = */ ggml_backend_webgpu_free, /* .set_tensor_async = */ ggml_backend_webgpu_set_tensor_async, /* .get_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_webgpu_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index e6b6fc24fd7..639b818d128 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -423,8 +423,8 @@ static ggml_backend_i ggml_backend_zdnn_i = { /* .free = */ ggml_backend_zdnn_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index fc1df4dbef4..2b82c7c1dbb 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -407,8 +407,8 @@ static struct ggml_backend_i ggml_backend_zendnn_i = { /* .free = */ ggml_backend_zendnn_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, From b34a9f3d83e8443835ad42778885d3b5ec8b825a Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Fri, 1 May 2026 06:19:10 +0900 Subject: [PATCH 230/249] ggml-webgpu: Improve performance of mat-vec and mat-mat for MUL_MAT_ID (llama/22464) * Add mat-vec fast path of MUL_MAT_ID. * Add shared accumulation vec logic and the other types supports. * Add i-quant mat-mat for MUL_MAT_ID and fix some parts * Remove n_experts from shader_lib_context. --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 173 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 73 +- .../wgsl-shaders/mul_mat_id_vec.wgsl | 154 ++ .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 1284 +-------------- .../wgsl-shaders/mul_mat_vec_acc.tmpl | 1391 +++++++++++++++++ 5 files changed, 1780 insertions(+), 1295 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 5239164cd00..0f66275c6a3 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -664,7 +664,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ } const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; + size_t bytes_per_kv = 0; if (!key.kv_direct) { bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); } @@ -701,10 +701,10 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && - (context.src2->type == K->type); + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + (context.src2->type == K->type); const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && @@ -862,9 +862,12 @@ struct ggml_webgpu_mul_mat_shader_decisions { struct ggml_webgpu_mul_mat_id_pipeline_key { ggml_type src0_type; ggml_type src1_type; + uint32_t n_experts; + int vectorized; bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const { - return src0_type == other.src0_type && src1_type == other.src1_type; + return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts && + vectorized == other.vectorized; } }; @@ -873,6 +876,8 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.n_experts); + ggml_webgpu_hash_combine(seed, key.vectorized); return seed; } }; @@ -1023,6 +1028,8 @@ class ggml_webgpu_shader_lib { std::unordered_map mul_mat_id_gather_pipelines; // key is fixed std::unordered_map mul_mat_id_pipelines; // src0_type/src1_type + std::unordered_map + mul_mat_id_vec_pipelines; // src0_type/src1_type std::unordered_map set_rows_pipelines; @@ -1516,7 +1523,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1633,10 +1640,10 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -2012,6 +2019,11 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_id_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; + key.n_experts = context.src0->ne[2]; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -2041,14 +2053,12 @@ class ggml_webgpu_shader_lib { switch (context.src0->type) { case GGML_TYPE_F32: defines.push_back("SRC0_INNER_TYPE=f32"); - defines.push_back("FLOAT"); defines.push_back("INIT_SRC0_SHMEM_FLOAT"); defines.push_back("INIT_SRC1_SHMEM_FLOAT"); variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("SRC0_INNER_TYPE=f16"); - defines.push_back("FLOAT"); defines.push_back("INIT_SRC0_SHMEM_FLOAT"); defines.push_back("INIT_SRC1_SHMEM_FLOAT"); variant += "_f16"; @@ -2064,12 +2074,32 @@ class ggml_webgpu_shader_lib { defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } + variant += std::string("_") + src0_name; break; } } - defines.push_back("SCALAR"); + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); // mul_mat_id is register-tile only. const uint32_t tile_k = @@ -2102,6 +2132,123 @@ class ggml_webgpu_shader_lib { return mul_mat_id_pipelines[key]; } + webgpu_pipeline get_mul_mat_id_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_id_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.n_experts = context.src0->ne[2]; + key.vectorized = (context.src0->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + + auto it = mul_mat_id_vec_pipelines.find(key); + if (it != mul_mat_id_vec_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "mul_mat_id_vec"; + const char * shader_src = wgsl_mul_mat_id_vec; + + // src1 type + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); + } + + // src0 type + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("MUL_ACC_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("MUL_ACC_FLOAT"); + variant += "_f16"; + break; + default: + { + // Quantized types: use helpers but accumulate in f16 + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + variant += "_" + src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } + break; + } + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; + + if (key.src0_type == GGML_TYPE_Q1_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q2_K) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q4_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } + + // variant suffix for src1 type + variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + if (key.vectorized) { + variant += "_vectorized"; + } + + defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts)); + + auto processed = preprocessor.preprocess(shader_src, defines); + + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + decisions->outputs_per_wg = outputs_per_wg; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_id_vec_pipelines[key] = pipeline; + return mul_mat_id_vec_pipelines[key]; + } + webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool is_unary = context.dst->op == GGML_OP_UNARY; const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index a1dccfc0f5a..f102c7a818b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1404,7 +1404,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: case GGML_TYPE_Q6_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: @@ -1527,11 +1526,74 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } +static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + const uint32_t param_n_expert = (uint32_t) src0->ne[2]; + const uint32_t param_n_expert_used = (uint32_t) dst->ne[1]; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_mul_mat_id_vec_pipeline(shader_lib_ctx); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + param_n_expert, + param_n_expert_used, + (uint32_t) src1->ne[1], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + }; + + std::vector entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2), + ggml_webgpu_tensor_binding_size(ctx, src2)), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + }; + + uint32_t wg_x = 1; + uint32_t wg_y = 1; + + auto * decisions = static_cast(pipeline.context.get()); + + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); + uint32_t total_wg = output_groups * param_n_expert_used; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { + // we can use mat-vec fast path + if (dst->ne[2] == 1) { + return ggml_webgpu_mul_mat_id_vec(ctx, src0, src1, src2, dst); + } + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; shader_lib_ctx.src1 = src1; @@ -3879,6 +3941,15 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: supports_op = true; break; default: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl new file mode 100644 index 00000000000..6ff9bcf2df0 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl @@ -0,0 +1,154 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif +enable f16; + +#define DECLARE_BYTE_LOADERS_SRC0 +#include "common_decls.tmpl" + +#include "mul_mat_vec_acc.tmpl" + +struct MulMatIdVecParams { + offset_src0: u32, + offset_src1: u32, + offset_ids: u32, + offset_dst: u32, + + k: u32, + m: u32, + n_expert: u32, + n_expert_used: u32, + b_ne1: u32, + + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, +}; + +@group(0) @binding(0) var src0: array; // [cols, rows, n_expert] +@group(0) @binding(1) var src1: array; // [cols, b_ne1, n_tokens(1)] +@group(0) @binding(2) var ids: array; // [n_experd_used, n_tokens(1)] +@group(0) @binding(3) var dst: array; // [rows, n_expert_used, n_tokens(1)] + +// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 +@group(0) @binding(4) var params: MulMatIdVecParams; + +// Flattened as [row][thread] to keep each row's reduction contiguous in memory. +var partial_sums: array; + +fn partial_index(row: u32, thread: u32) -> u32 { + return row * WG_SIZE + thread; +} + +var gathered_count_ids: array; +var gathered_expert_used: array; + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32 +#endif +) { + + let thread_id = local_id.x; + + for (var i = thread_id;i < params.n_expert;i += WG_SIZE) { + gathered_count_ids[i] = 0; + } + + workgroupBarrier(); + + // gather the selected experts for the target token. + for (var col = thread_id;col < params.n_expert_used;col += WG_SIZE) { + let expert = ids[params.offset_ids + col]; + gathered_count_ids[expert] = 1; + gathered_expert_used[expert] = col; + } + + workgroupBarrier(); + + let output_groups:u32 = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + var own_expert:u32 = 0; + var wg_in_batch:u32 = 0; + var wg_sum:u32 = 0; + + for (var i = 0u;i < params.n_expert;i += 1) { + let wg_vec_count = gathered_count_ids[i]; // 1 or 0 + let wg_per_matrix = output_groups * wg_vec_count; + if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) { + own_expert = i; + wg_in_batch = wg_linear - wg_sum; + break; + } + wg_sum += wg_per_matrix; + } + + let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG; + let dst1_stride = params.m; + + let src0_batch_offset = params.offset_src0 + own_expert * params.stride_02; + let src1_idx_base = params.offset_src1 + (gathered_expert_used[own_expert] % params.b_ne1) * params.stride_11; + let dst_idx_base = params.offset_dst + gathered_expert_used[own_expert] * dst1_stride + row_base; + + let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); + +#ifdef USE_SUBGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; + } + } + + workgroupBarrier(); + + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; + } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + row] = row_total; + } + } +#endif + +#ifdef USE_WORKGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[row]; + } + + workgroupBarrier(); + + var stride:u32 = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } + } + + workgroupBarrier(); + stride = stride / 2; + } + + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } + } +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index a8000439bfb..a194cf40468 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -6,38 +6,7 @@ enable f16; #define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" -#ifdef U32_DEQUANT_HELPERS -#define SRC0_TYPE u32 - -fn byte_of(v: u32, b: u32) -> u32 { - return (v >> (b * 8u)) & 0xFFu; -} - -fn sbyte_of(v: u32, b: u32) -> i32 { - let raw = i32((v >> (b * 8u)) & 0xFFu); - return select(raw, raw - 256, raw >= 128); -} -#endif - -#ifdef VEC -#define VEC_SIZE 4u -#define SRC0_TYPE vec4 -#define SRC1_TYPE vec4 - -fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { - return f32(dot(SRC1_TYPE(src0_val), src1_val)); -} -#endif - -#ifdef SCALAR -#define VEC_SIZE 1u -#define SRC0_TYPE SRC0_INNER_TYPE -#define SRC1_TYPE SRC1_INNER_TYPE - -fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { - return f32(src0_val) * f32(src1_val); -} -#endif +#include "mul_mat_vec_acc.tmpl" struct MulMatParams { offset_src0: u32, @@ -62,6 +31,7 @@ struct MulMatParams { @group(0) @binding(1) var src1: array; @group(0) @binding(2) var dst: array; +// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 @group(0) @binding(3) var params: MulMatParams; // Flattened as [row][thread] to keep each row's reduction contiguous in memory. @@ -108,1255 +78,7 @@ fn main( let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; - var acc: array; - -#ifdef MUL_ACC_FLOAT - let k_vec = params.k / VEC_SIZE; - let src1_idx_base_vec = src1_idx_base / VEC_SIZE; - - // Each thread walks K, loads from the vector, and updates - // a small block of output rows held in registers. - for (var k = thread_id; k < k_vec; k += WG_SIZE) { - let x = src1[src1_idx_base_vec + k]; - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; - acc[row] += inner_dot(src0[src0_idx], x); - } - } - } -#endif - -#ifdef MUL_ACC_Q1_0 -#define BLOCK_SIZE 128 -#define BLOCK_SIZE_BYTES 18 -#define THREADS_PER_BLOCK 16 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; - var row_sum = 0.0; - for (var bit = 0u; bit < 8u; bit++) { - let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); - row_sum += w * x_block[bit]; - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q4_0 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 18 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % 4; - for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; - let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q4_1 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 20 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(q_byte & 0xFu) * d + m; - let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q5_0 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 22 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qh_packed = load_u32_at_src0(block_byte_base + 2u); - let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); - let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; - let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q5_1 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 24 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = f32(load_f16_at_src0(block_byte_base + 2u)); - let qh_packed = load_u32_at_src0(block_byte_base + 4u); - let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); - let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; - let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q8_0 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 34 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - - for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q8_1 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 36 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - - for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q2_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 84 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let lane = tid / 2u; - let phase = tid % 2u; - let iq = lane / 4u; - let ir = lane % 4u; - let is = ir / 2u; - - let y_offset = 128u * iq + 8u * ir + 4u * phase; - let sc0_byte = 8u * iq + is; - let sc2_byte = 8u * iq + is + 2u; - let sc4_byte = 8u * iq + is + 4u; - let sc6_byte = 8u * iq + is + 6u; - let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 64u + i]); - x_block[i + 12u] = f32(src1[x_base + 96u + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let dall = f32(load_f16_at_src0(block_byte_base + 80u)); - let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0); - - let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u); - let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u); - let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u); - let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u); - - let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte); - let qs0 = q_u32 & 0xFFFFu; - let qs1 = q_u32 >> 16u; - - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - var acc1 = vec4(0.0, 0.0, 0.0, 0.0); - var acc2 = vec4(0.0, 0.0, 0.0, 0.0); - - sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; - sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; - sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; - sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; - - acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); - acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); - acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); - acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); - acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); - acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); - acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); - acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); - - acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + - (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + - (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + - (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) - - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + - sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); - } - } - } -#endif - - -#ifdef MUL_ACC_Q3_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 110 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let lane = tid / 2u; - let phase = tid % 2u; - let ip = lane / 4u; - let il = 2u * ((lane % 4u) / 2u); - let ir = lane % 2u; - let l0 = 8u * ir; - - let q_byte = 32u + 32u * ip + l0 + 16u * phase; - let h_byte = l0 + 16u * phase; - let y_offset = 128u * ip + 32u * il + l0 + 16u * phase; - - let s_shift1 = 4u * ip; - let s_shift2 = s_shift1 + il; - - let v1 = select(64.0, 4.0, il == 0u); - let v2 = 4.0 * v1; - let shift = 2u * il; - - var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32; - if (il == 0u) { - qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u; - } else { - qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u; - } - - let mm_idx = 2u * ip + il / 2u; - var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32; - switch (mm_idx) { - case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; } - case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; } - case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; } - default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; } - } - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 8u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 8u] = f32(src1[x_base + 32u + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base + 108u)); - let a_base = 96u; - let a_il0 = load_u16_at_src0(block_byte_base + a_base + il * 2u); - let a_il1 = load_u16_at_src0(block_byte_base + a_base + (il + 1u) * 2u); - let a_4 = load_u16_at_src0(block_byte_base + a_base + 8u); - let a_5 = load_u16_at_src0(block_byte_base + a_base + 10u); - - var scales32 = a_4 | (a_5 << 16u); - let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; - scales32 = a_il0 | (a_il1 << 16u); - scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; - - let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32); - let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32); - - let q_u32_0 = load_u32_at_src0(block_byte_base + q_byte + 0u); - let q_u32_1 = load_u32_at_src0(block_byte_base + q_byte + 4u); - let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); - let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); - - var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; - var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; - - for (var l = 0u; l < 8u; l += 2u) { - let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); - let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); - let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); - let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); - - s1 += x_block[l + 0u] * f32(qs & qm0); - s2 += x_block[l + 1u] * f32(qs & qm1); - s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + - select(0.0, x_block[l + 1u], (hv & hm1) == 0u); - s4 += x_block[l + 8u] * f32(qs & qm2); - s5 += x_block[l + 9u] * f32(qs & qm3); - s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + - select(0.0, x_block[l + 9u], (hv & hm3) == 0u); - } - - let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); - let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); - acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); - } - } - } -#endif - -#ifdef MUL_ACC_Q4_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 144 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let il = tid / 4u; - let ir = tid % 4u; - let im = il / 2u; - let in = il % 2u; - let l0 = 4u * (2u * ir + in); - - let y_offset = 64u * im + l0; - let q_offset = 32u * im + l0; - let sc0_byte = 4u + im * 2u; - let sc2_byte = 4u + (im + 2u) * 2u; - let sc4_byte = 4u + (im + 4u) * 2u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base + 0u)); - let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); - - let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); - let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); - let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); - let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); - let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); - let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); - - let sc16_0 = sc0 & 0x3F3Fu; - let sc16_1 = sc2 & 0x3F3Fu; - let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); - let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); - - let scale0 = f32(sc16_0 & 0xFFu); - let scale1 = f32((sc16_0 >> 8u) & 0xFFu); - let min0 = f32(sc16_1 & 0xFFu); - let min1 = f32((sc16_1 >> 8u) & 0xFFu); - let scale2 = f32(sc16_2 & 0xFFu); - let scale3 = f32((sc16_2 >> 8u) & 0xFFu); - let min2 = f32(sc16_3 & 0xFFu); - let min3 = f32((sc16_3 >> 8u) & 0xFFu); - - let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); - let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); - - var dot = vec4(0.0, 0.0, 0.0, 0.0); - var sumx = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - dot[0] += x_block[i] * f32(q1b & 0x0Fu); - dot[1] += x_block[i + 4u] * f32(q1b >> 4u); - dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); - dot[3] += x_block[i + 12u] * f32(q2b >> 4u); - sumx[0] += x_block[i]; - sumx[1] += x_block[i + 4u]; - sumx[2] += x_block[i + 8u]; - sumx[3] += x_block[i + 12u]; - } - - acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) - - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); - } - } - } -#endif - -#ifdef MUL_ACC_Q5_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 176 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let il = tid / 4u; - let ir = tid % 4u; - let im = il / 2u; - let in = il % 2u; - let l0 = 4u * (2u * ir + in); - - let y_offset = 64u * im + l0; - let q_offset = 48u + 32u * im + l0; - let qh_offset = 16u + 8u * ir + 4u * in; - let sc0_byte = 4u + im * 2u; - let sc2_byte = 4u + (im + 2u) * 2u; - let sc4_byte = 4u + (im + 4u) * 2u; - - let hm1 = 1u << (2u * im); - let hm2 = hm1 << 1u; - let hm3 = hm1 << 4u; - let hm4 = hm2 << 4u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base + 0u)); - let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); - - let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); - let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); - let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); - let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); - let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); - let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); - - let sc16_0 = sc0 & 0x3F3Fu; - let sc16_1 = sc2 & 0x3F3Fu; - let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); - let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); - - let f0 = f32(sc16_0 & 0xFFu); - let f1 = f32((sc16_0 >> 8u) & 0xFFu); - let m0 = f32(sc16_1 & 0xFFu); - let m1 = f32((sc16_1 >> 8u) & 0xFFu); - let f4 = f32(sc16_2 & 0xFFu); - let f5 = f32((sc16_2 >> 8u) & 0xFFu); - let m4 = f32(sc16_3 & 0xFFu); - let m5 = f32((sc16_3 >> 8u) & 0xFFu); - - let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset); - let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); - let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); - - var vals = vec4(0.0, 0.0, 0.0, 0.0); - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - let qhb = byte_of(qh_u32, i); - - let yl0 = x_block[i]; - let yl8 = x_block[i + 4u]; - let yh0 = x_block[i + 8u]; - let yh8 = x_block[i + 12u]; - - sumy[0] += yl0; - sumy[1] += yl8; - sumy[2] += yh0; - sumy[3] += yh8; - - let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); - let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); - let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); - let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); - - vals[0] += yl0 * q0; - vals[1] += yl8 * q1; - vals[2] += yh0 * q2; - vals[3] += yh8 * q3; - } - - acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) - - dmin * (sumy[0] * m0 + sumy[1] * m1 + - sumy[2] * m4 + sumy[3] * m5); - } - } - } -#endif - -#ifdef MUL_ACC_Q6_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 210 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let ip = tid / 8u; - let il = tid % 8u; - let l0 = 4u * il; - let is = 8u * ip + l0 / 16u; - - let y_offset = 128u * ip + l0; - let q_offset_l = 64u * ip + l0; - let q_offset_h = 32u * ip + l0; - - let num_blocks = params.k / BLOCK_SIZE; - let sc_base_byte = 192u + (is & ~3u); - let sc_byte_pos = is & 3u; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var l = 0u; l < 4u; l++) { - x_block[l] = f32(src1[x_base + l]); - x_block[l + 4u] = f32(src1[x_base + 32u + l]); - x_block[l + 8u] = f32(src1[x_base + 64u + l]); - x_block[l + 12u] = f32(src1[x_base + 96u + l]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base + 208u)); - let ql1_u32 = load_u32_at_src0(block_byte_base + q_offset_l); - let ql2_u32 = load_u32_at_src0(block_byte_base + q_offset_l + 32u); - let qh_u32 = load_u32_at_src0(block_byte_base + 128u + q_offset_h); - let sc_u32_0 = load_u32_at_src0(block_byte_base + sc_base_byte); - let sc_u32_1 = load_u32_at_src0(block_byte_base + sc_base_byte + 4u); - - let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); - let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); - let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); - let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); - - var sums = vec4(0.0, 0.0, 0.0, 0.0); - - for (var l = 0u; l < 4u; l++) { - let q1b = byte_of(ql1_u32, l); - let q2b = byte_of(ql2_u32, l); - let qhb = byte_of(qh_u32, l); - - let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); - let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); - let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); - let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); - - sums[0] += x_block[l] * dq0; - sums[1] += x_block[l + 4u] * dq1; - sums[2] += x_block[l + 8u] * dq2; - sums[3] += x_block[l + 12u] * dq3; - } - - acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + - sums[2] * f32(sc4) + sums[3] * f32(sc6)); - } - } - } -#endif - -#ifdef MUL_ACC_IQ1_S -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 50 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base)); - let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu; - let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u); - let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); - let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ1_M -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 56 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let sc_lo = load_u32_at_src0(block_byte_base + 48u); - let sc_hi = load_u32_at_src0(block_byte_base + 52u); - let sc0 = sc_lo & 0xFFFFu; - let sc1 = (sc_lo >> 16u) & 0xFFFFu; - let sc2 = sc_hi & 0xFFFFu; - let sc3 = (sc_hi >> 16u) & 0xFFFFu; - let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u); - let d = f32(bitcast>(d_bits)[0]); - - let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u), - select(sc0, sc1, sub_blk >= 2u), - sub_blk < 4u); - - let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u); - let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu; - let qh_lo = qh & 0xFFu; - let qh_hi = (qh >> 8u) & 0xFFu; - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); - let sub_scale = (sc_u16 >> bit_off) & 0x7u; - let dl = d * f32(2u * sub_scale + 1u); - let qh_byte = select(qh_lo, qh_hi, l >= 2u); - let ll2 = l % 2u; - let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); - let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); - let ig = grid_idx * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ2_XXS -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 66 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); - let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); - let ls = aux_hi >> 28u; - let db = d * (0.5 + f32(ls)) * 0.25; - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; - let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xxs_grid[grid_idx * 2u]; - let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ2_XS -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 74 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); - let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); - let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); - let scales_byte = get_byte(scales_word, sub_blk % 4u); - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let half2 = (l % 2u) * 16u; - let qs_val = (qs_word >> half2) & 0xFFFFu; - let grid_idx = qs_val & 0x1FFu; - let signs_idx = (qs_val >> 9u) & 0x7Fu; - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xs_grid[grid_idx * 2u]; - let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ2_S -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 82 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); - let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u); - let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); - let qh_byte = get_byte(qh_word, sub_blk % 4u); - let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); - let scales_byte = get_byte(sc_word, sub_blk % 4u); - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let sign_byte = get_byte(sg_w, l); - let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let gw_lo = iq2s_grid[grid_idx * 2u]; - let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ3_XXS -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 98 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); - let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); - let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u); - let ls = aux >> 28u; - let db = d * (0.5 + f32(ls)) * 0.5; - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let signs_idx = (aux >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let grid1 = iq3xxs_grid[grid_idx_0]; - let grid2 = iq3xxs_grid[grid_idx_1]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ3_S -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 110 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); - let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); - let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); - let qh_byte = get_byte(qh_word, sub_blk % 4u); - let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u); - let sc_word = load_u32_at_src0(block_byte_base + 106u); - let scales_byte = get_byte(sc_word, sub_blk / 2u); - let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; - let db = d * (1.0 + 2.0 * f32(sub_scale)); - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); - let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); - let sign_byte = get_byte(sg_w, l); - let grid1 = iq3s_grid[grid_idx_1]; - let grid2 = iq3s_grid[grid_idx_2]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ4_NL -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 18 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + i + 16u]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; - let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ4_XS -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 136 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let y_offset = sub_blk * 32u + half * 16u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let scales_h = load_u16_at_src0(block_byte_base + 2u); - let scales_l_word = load_u32_at_src0(block_byte_base + 4u); - let sl_byte = get_byte(scales_l_word, sub_blk / 2u); - let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu; - let sh_bits = (scales_h >> (2u * sub_blk)) & 3u; - let ls = i32(sl | (sh_bits << 4u)); - let dl = d * f32(ls - 32); - - let qs_byte_off = 8u + sub_blk * 16u; - let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off); - let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u); - let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); - let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); - - var row_sum = 0.0; - for (var i = 0u; i < 16u; i++) { - let q_word = select( - select(q_w0, q_w1, i >= 4u), - select(q_w2, q_w3, i >= 12u), - i >= 8u); - let q_byte = get_byte(q_word, i % 4u); - let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); - row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; - } - acc[row] += row_sum; - } - } - } -#endif + let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl new file mode 100644 index 00000000000..1f59bd14863 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -0,0 +1,1391 @@ +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 + +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif + +#ifdef VEC +#define VEC_SIZE 4u +#define SRC0_TYPE vec4 +#define SRC1_TYPE vec4 + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(dot(SRC1_TYPE(src0_val), src1_val)); +} +#endif + +#ifdef SCALAR +#define VEC_SIZE 1u +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(src0_val) * f32(src1_val); +} +#endif + +#ifdef MUL_ACC_FLOAT +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let k_vec = params.k / VEC_SIZE; + let src1_idx_base_vec = src1_idx_base / VEC_SIZE; + + // Each thread walks K, loads from the vector, and updates + // a small block of output rows held in registers. + for (var k = thread_id; k < k_vec; k += WG_SIZE) { + let x = src1[src1_idx_base_vec + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; + acc[row] += inner_dot(src0[src0_idx], x); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q1_0 +#define BLOCK_SIZE 128 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 16 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; + var row_sum = 0.0; + for (var bit = 0u; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + row_sum += w * x_block[bit]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % 4; + for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 20 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q5_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 22 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); + let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q5_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 24 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + let qh_packed = load_u32_at_src0(block_byte_base + 4u); + let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 34 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q8_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 36 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 84 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let iq = lane / 4u; + let ir = lane % 4u; + let is = ir / 2u; + + let y_offset = 128u * iq + 8u * ir + 4u * phase; + let sc0_byte = 8u * iq + is; + let sc2_byte = 8u * iq + is + 2u; + let sc4_byte = 8u * iq + is + 4u; + let sc6_byte = 8u * iq + is + 6u; + let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 64u + i]); + x_block[i + 12u] = f32(src1[x_base + 96u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let dall = f32(load_f16_at_src0(block_byte_base + 80u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0); + + let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u); + let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u); + let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u); + let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u); + + let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte); + let qs0 = q_u32 & 0xFFFFu; + let qs1 = q_u32 >> 16u; + + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + + sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; + sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; + sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; + sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + + acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + + acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } + } + } + + return acc; +} +#endif + + +#ifdef MUL_ACC_Q3_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let ip = lane / 4u; + let il = 2u * ((lane % 4u) / 2u); + let ir = lane % 2u; + let l0 = 8u * ir; + + let q_byte = 32u + 32u * ip + l0 + 16u * phase; + let h_byte = l0 + 16u * phase; + let y_offset = 128u * ip + 32u * il + l0 + 16u * phase; + + let s_shift1 = 4u * ip; + let s_shift2 = s_shift1 + il; + + let v1 = select(64.0, 4.0, il == 0u); + let v2 = 4.0 * v1; + let shift = 2u * il; + + var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32; + if (il == 0u) { + qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u; + } else { + qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u; + } + + let mm_idx = 2u * ip + il / 2u; + var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32; + switch (mm_idx) { + case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; } + case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; } + case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; } + default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; } + } + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 8u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 8u] = f32(src1[x_base + 32u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 108u)); + let a_base = 96u; + let a_il0 = load_u16_at_src0(block_byte_base + a_base + il * 2u); + let a_il1 = load_u16_at_src0(block_byte_base + a_base + (il + 1u) * 2u); + let a_4 = load_u16_at_src0(block_byte_base + a_base + 8u); + let a_5 = load_u16_at_src0(block_byte_base + a_base + 10u); + + var scales32 = a_4 | (a_5 << 16u); + let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; + scales32 = a_il0 | (a_il1 << 16u); + scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; + + let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32); + let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32); + + let q_u32_0 = load_u32_at_src0(block_byte_base + q_byte + 0u); + let q_u32_1 = load_u32_at_src0(block_byte_base + q_byte + 4u); + let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); + let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); + + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + + s1 += x_block[l + 0u] * f32(qs & qm0); + s2 += x_block[l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[l + 1u], (hv & hm1) == 0u); + s4 += x_block[l + 8u] * f32(qs & qm2); + s5 += x_block[l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 144 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 32u * im + l0; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let scale0 = f32(sc16_0 & 0xFFu); + let scale1 = f32((sc16_0 >> 8u) & 0xFFu); + let min0 = f32(sc16_1 & 0xFFu); + let min1 = f32((sc16_1 >> 8u) & 0xFFu); + let scale2 = f32(sc16_2 & 0xFFu); + let scale3 = f32((sc16_2 >> 8u) & 0xFFu); + let min2 = f32(sc16_3 & 0xFFu); + let min3 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); + + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[i] * f32(q1b & 0x0Fu); + dot[1] += x_block[i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[i]; + sumx[1] += x_block[i + 4u]; + sumx[2] += x_block[i + 8u]; + sumx[3] += x_block[i + 12u]; + } + + acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q5_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 176 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 48u + 32u * im + l0; + let qh_offset = 16u + 8u * ir + 4u * in; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let hm1 = 1u << (2u * im); + let hm2 = hm1 << 1u; + let hm3 = hm1 << 4u; + let hm4 = hm2 << 4u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let f0 = f32(sc16_0 & 0xFFu); + let f1 = f32((sc16_0 >> 8u) & 0xFFu); + let m0 = f32(sc16_1 & 0xFFu); + let m1 = f32((sc16_1 >> 8u) & 0xFFu); + let f4 = f32(sc16_2 & 0xFFu); + let f5 = f32((sc16_2 >> 8u) & 0xFFu); + let m4 = f32(sc16_3 & 0xFFu); + let m5 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); + let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); + + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); + + let yl0 = x_block[i]; + let yl8 = x_block[i + 4u]; + let yh0 = x_block[i + 8u]; + let yh8 = x_block[i + 12u]; + + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; + + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q6_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 210 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let ip = tid / 8u; + let il = tid % 8u; + let l0 = 4u * il; + let is = 8u * ip + l0 / 16u; + + let y_offset = 128u * ip + l0; + let q_offset_l = 64u * ip + l0; + let q_offset_h = 32u * ip + l0; + + let num_blocks = params.k / BLOCK_SIZE; + let sc_base_byte = 192u + (is & ~3u); + let sc_byte_pos = is & 3u; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var l = 0u; l < 4u; l++) { + x_block[l] = f32(src1[x_base + l]); + x_block[l + 4u] = f32(src1[x_base + 32u + l]); + x_block[l + 8u] = f32(src1[x_base + 64u + l]); + x_block[l + 12u] = f32(src1[x_base + 96u + l]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 208u)); + let ql1_u32 = load_u32_at_src0(block_byte_base + q_offset_l); + let ql2_u32 = load_u32_at_src0(block_byte_base + q_offset_l + 32u); + let qh_u32 = load_u32_at_src0(block_byte_base + 128u + q_offset_h); + let sc_u32_0 = load_u32_at_src0(block_byte_base + sc_base_byte); + let sc_u32_1 = load_u32_at_src0(block_byte_base + sc_base_byte + 4u); + + let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); + let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); + let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); + let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); + + var sums = vec4(0.0, 0.0, 0.0, 0.0); + + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); + + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + + sums[0] += x_block[l] * dq0; + sums[1] += x_block[l + 4u] * dq1; + sums[2] += x_block[l + 8u] * dq2; + sums[3] += x_block[l + 12u] * dq3; + } + + acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ1_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 50 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base)); + let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu; + let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); + let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ1_M +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 56 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let sc_lo = load_u32_at_src0(block_byte_base + 48u); + let sc_hi = load_u32_at_src0(block_byte_base + 52u); + let sc0 = sc_lo & 0xFFFFu; + let sc1 = (sc_lo >> 16u) & 0xFFFFu; + let sc2 = sc_hi & 0xFFFFu; + let sc3 = (sc_hi >> 16u) & 0xFFFFu; + let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u); + let d = f32(bitcast>(d_bits)[0]); + + let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u), + select(sc0, sc1, sub_blk >= 2u), + sub_blk < 4u); + + let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u); + let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu; + let qh_lo = qh & 0xFFu; + let qh_hi = (qh >> 8u) & 0xFFu; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); + let sub_scale = (sc_u16 >> bit_off) & 0x7u; + let dl = d * f32(2u * sub_scale + 1u); + let qh_byte = select(qh_lo, qh_hi, l >= 2u); + let ll2 = l % 2u; + let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); + let ig = grid_idx * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ2_XXS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 66 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let ls = aux_hi >> 28u; + let db = d * (0.5 + f32(ls)) * 0.25; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; + let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xxs_grid[grid_idx * 2u]; + let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ2_XS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 74 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let scales_byte = get_byte(scales_word, sub_blk % 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let half2 = (l % 2u) * 16u; + let qs_val = (qs_word >> half2) & 0xFFFFu; + let grid_idx = qs_val & 0x1FFu; + let signs_idx = (qs_val >> 9u) & 0x7Fu; + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xs_grid[grid_idx * 2u]; + let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ2_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 82 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); + let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let qh_byte = get_byte(qh_word, sub_blk % 4u); + let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); + let scales_byte = get_byte(sc_word, sub_blk % 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let sign_byte = get_byte(sg_w, l); + let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let gw_lo = iq2s_grid[grid_idx * 2u]; + let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ3_XXS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 98 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u); + let ls = aux >> 28u; + let db = d * (0.5 + f32(ls)) * 0.5; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let signs_idx = (aux >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let grid1 = iq3xxs_grid[grid_idx_0]; + let grid2 = iq3xxs_grid[grid_idx_1]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ3_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let qh_byte = get_byte(qh_word, sub_blk % 4u); + let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u); + let sc_word = load_u32_at_src0(block_byte_base + 106u); + let scales_byte = get_byte(sc_word, sub_blk / 2u); + let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; + let db = d * (1.0 + 2.0 * f32(sub_scale)); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); + let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); + let sign_byte = get_byte(sg_w, l); + let grid1 = iq3s_grid[grid_idx_1]; + let grid2 = iq3s_grid[grid_idx_2]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ4_NL +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + i + 16u]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; + let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ4_XS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 136 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let y_offset = sub_blk * 32u + half * 16u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let scales_h = load_u16_at_src0(block_byte_base + 2u); + let scales_l_word = load_u32_at_src0(block_byte_base + 4u); + let sl_byte = get_byte(scales_l_word, sub_blk / 2u); + let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu; + let sh_bits = (scales_h >> (2u * sub_blk)) & 3u; + let ls = i32(sl | (sh_bits << 4u)); + let dl = d * f32(ls - 32); + + let qs_byte_off = 8u + sub_blk * 16u; + let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off); + let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u); + let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); + let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); + + var row_sum = 0.0; + for (var i = 0u; i < 16u; i++) { + let q_word = select( + select(q_w0, q_w1, i >= 4u), + select(q_w2, q_w3, i >= 12u), + i >= 8u); + let q_byte = get_byte(q_word, i % 4u); + let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); + row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif From ccd04522f96ff68cdba1312cca8e7472a4a8bb13 Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Fri, 1 May 2026 01:22:18 -0400 Subject: [PATCH 231/249] ggml-webgpu: add the upscale shader (llama/22419) * shader(upscale): add the upscale shader with nearest, bilinear and bicubic implementations * shader(upscale): use macro --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 94 +++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 49 ++++ .../src/ggml-webgpu/wgsl-shaders/upscale.wgsl | 240 ++++++++++++++++++ 3 files changed, 383 insertions(+) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 0f66275c6a3..651c9cbcdf6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1,6 +1,7 @@ #ifndef GGML_WEBGPU_SHADER_LIB_HPP #define GGML_WEBGPU_SHADER_LIB_HPP +#include "ggml-impl.h" #include "ggml-wgsl-shaders.hpp" #include "ggml.h" #include "pre_wgsl.hpp" @@ -405,6 +406,31 @@ struct ggml_webgpu_scale_pipeline_key_hash { } }; +/** Upscale **/ + +struct ggml_webgpu_upscale_pipeline_key { + ggml_type input_type; + ggml_type output_type; + uint32_t base_mode; + bool antialias; + + bool operator==(const ggml_webgpu_upscale_pipeline_key & other) const { + return input_type == other.input_type && output_type == other.output_type && base_mode == other.base_mode && + antialias == other.antialias; + } +}; + +struct ggml_webgpu_upscale_pipeline_key_hash { + size_t operator()(const ggml_webgpu_upscale_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + ggml_webgpu_hash_combine(seed, key.base_mode); + ggml_webgpu_hash_combine(seed, key.antialias); + return seed; + } +}; + /** Concat **/ struct ggml_webgpu_concat_pipeline_key { @@ -1049,6 +1075,8 @@ class ggml_webgpu_shader_lib { webgpu_pipeline, ggml_webgpu_rms_norm_mul_pipeline_key_hash> rms_norm_mul_pipelines; + std::unordered_map + upscale_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -2947,6 +2975,72 @@ class ggml_webgpu_shader_lib { return im2col_pipelines[key]; } + webgpu_pipeline get_upscale_pipeline(const ggml_webgpu_shader_lib_context & context) { + const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(context.dst, 0); + const uint32_t base_mode = mode_flags & 0xFFu; + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS) != 0u; + + ggml_webgpu_upscale_pipeline_key key = {}; + key.input_type = context.src0->type; + key.output_type = context.dst->type; + key.base_mode = base_mode; + key.antialias = antialias; + + auto it = upscale_pipelines.find(key); + if (it != upscale_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "upscale"; + + if (key.input_type == GGML_TYPE_F16) { + defines.push_back("SRC_F16"); + variant += "_src_f16"; + } else { + variant += "_src_f32"; + } + + if (key.output_type == GGML_TYPE_F16) { + defines.push_back("DST_F16"); + variant += "_dst_f16"; + } else { + variant += "_dst_f32"; + } + + switch (base_mode) { + case GGML_SCALE_MODE_NEAREST: + defines.push_back("NEAREST"); + variant += "_nearest"; + break; + case GGML_SCALE_MODE_BILINEAR: + defines.push_back("BILINEAR"); + variant += "_bilinear"; + break; + case GGML_SCALE_MODE_BICUBIC: + defines.push_back("BICUBIC"); + variant += "_bicubic"; + break; + default: + GGML_ABORT("Unsupported upscale mode"); + } + + if (antialias) { + defines.push_back("ANTIALIAS"); + variant += "_aa"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_upscale, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + upscale_pipelines[key] = pipeline; + return upscale_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f102c7a818b..cab0aead198 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2824,6 +2824,49 @@ static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, return true; } +static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * src, ggml_tensor * dst) { + const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(dst, 0); + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + + mode_flags }; + + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + // Returns the encoded command, or std::nullopt if the operation is a no-op static std::optional ggml_webgpu_encode(webgpu_context ctx, ggml_cgraph * cgraph, @@ -2931,6 +2974,8 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, return ggml_webgpu_conv_2d(ctx, src0, src1, node); case GGML_OP_IM2COL: return ggml_webgpu_im2col(ctx, src0, src1, node); + case GGML_OP_UPSCALE: + return ggml_webgpu_upscale(ctx, src0, node); default: return std::nullopt; } @@ -4163,6 +4208,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SUM_ROWS: supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0); break; + case GGML_OP_UPSCALE: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; default: break; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl new file mode 100644 index 00000000000..e9ef8822644 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl @@ -0,0 +1,240 @@ +#if defined(SRC_F16) || defined(DST_F16) +enable f16; +#endif + +#ifdef SRC_F16 +#define SRC_TYPE f16 +#else +#define SRC_TYPE f32 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + +@group(0) @binding(0) +var input: array; + +@group(0) @binding(1) +var output: array; + +struct Params { + offset_i: u32, + offset_o: u32, + + // element strides + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + src_w: u32, + src_h: u32, + src_z: u32, + src_n: u32, + + dst_w: u32, + dst_h: u32, + dst_z: u32, + dst_n: u32, + + mode_flags: u32, +}; + +@group(0) @binding(2) +var params: Params; + +const GGML_SCALE_FLAG_ALIGN_CORNERS: u32 = 1u << 8u; + +fn get_clamped_input(x: i32, y: i32, z: u32, n: u32) -> f32 { + let cx = u32(clamp(x, 0, i32(params.src_w) - 1)); + let cy = u32(clamp(y, 0, i32(params.src_h) - 1)); + let i = params.offset_i + cx * params.si0 + cy * params.si1 + z * params.si2 + n * params.si3; + return f32(input[i]); +} + +fn cubic_weight(t: f32, a: f32) -> f32 { + let at = abs(t); + if (at <= 1.0) { + return (a + 2.0) * at * at * at - (a + 3.0) * at * at + 1.0; + } else if (at <= 2.0) { + return a * at * at * at - 5.0 * a * at * at + 8.0 * a * at - 4.0 * a; + } else { + return 0.0; + } +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let i_out = gid.x + (num_wg.x * u32(WG_SIZE)) * gid.y; + let total = params.dst_w * params.dst_h * params.dst_z * params.dst_n; + + if (i_out >= total) { + return; + } + + // decode (x, y, z, n) + var i = i_out; + let x_dst = i % params.dst_w; + i = i / params.dst_w; + let y_dst = i % params.dst_h; + i = i / params.dst_h; + let z_dst = i % params.dst_z; + let n_dst = i / params.dst_z; + + // scale factors + var sf0 = f32(params.dst_w) / f32(params.src_w); + var sf1 = f32(params.dst_h) / f32(params.src_h); + var sf2 = f32(params.dst_z) / f32(params.src_z); + var sf3 = f32(params.dst_n) / f32(params.src_n); + + let align_corners = (params.mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) != 0; + + // pixel_offset: 0.5 for half-pixel-center (default), 0.0 for align_corners + var pixel_offset = 0.5; + if (align_corners) { + pixel_offset = 0.0; + if (params.dst_w > 1 && params.src_w > 1) { + sf0 = f32(params.dst_w - 1) / f32(params.src_w - 1); + } + if (params.dst_h > 1 && params.src_h > 1) { + sf1 = f32(params.dst_h - 1) / f32(params.src_h - 1); + } + } + + let z_src = min(params.src_z - 1, u32(floor(f32(z_dst) / sf2))); + let n_src = min(params.src_n - 1, u32(floor(f32(n_dst) / sf3))); + + var result = 0.0; + +#if defined(NEAREST) + + let x_src = min(params.src_w - 1, u32(floor(f32(x_dst) / sf0))); + let y_src = min(params.src_h - 1, u32(floor(f32(y_dst) / sf1))); + + result = get_clamped_input(i32(x_src), i32(y_src), z_src, n_src); + +#elif defined(BILINEAR) + +#if defined(ANTIALIAS) + + // Antialiased bilinear: triangle filter over a variable support region. + let support0 = max(1.0f / sf0, 1.0f); + let support1 = max(1.0f / sf1, 1.0f); + let invscale0 = 1.0 / support0; + let invscale1 = 1.0 / support1; + + let fx = (f32(x_dst) + pixel_offset) / sf0; + let fy = (f32(y_dst) + pixel_offset) / sf1; + + let x_min = max(i32(fx - support0 + pixel_offset), 0); + let y_min = max(i32(fy - support1 + pixel_offset), 0); + let x_max = min(i32(fx + support0 + pixel_offset), i32(params.src_w)); + let y_max = min(i32(fy + support1 + pixel_offset), i32(params.src_h)); + + var weighted_sum = 0.0; + var total_weight = 0.0; + + for (var x = x_min; x < x_max; x += 1) { + let wx = max(1.0 - abs(f32(x) - fx + pixel_offset) * invscale0, 0.0); + for (var y = y_min; y < y_max; y += 1) { + let wy = max(1.0 - abs(f32(y) - fy + pixel_offset) * invscale1, 0.0); + let w = wx * wy; + if (w > 0.0) { + weighted_sum += get_clamped_input(x, y, z_src, n_src) * w; + total_weight += w; + } + } + } + + if (total_weight > 0.0) { + result = weighted_sum / total_weight; + } + +#else + + let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset; + let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset; + let x0 = i32(floor(fx)); + let y0 = i32(floor(fy)); + let dx = clamp(fx - f32(x0), 0.0, 1.0); + let dy = clamp(fy - f32(y0), 0.0, 1.0); + let a = get_clamped_input(x0, y0, z_src, n_src); + let b = get_clamped_input(x0 + 1, y0, z_src, n_src); + let c = get_clamped_input(x0, y0 + 1, z_src, n_src); + let d = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src); + + let wa = (1.0 - dx) * (1.0 - dy); + let wb = dx * (1.0 - dy); + let wc = (1.0 - dx) * dy; + let wd = dx * dy; + + result = a * wa + b * wb + c * wc + d * wd; + +#endif + +#elif defined(BICUBIC) + + // bicubic convolution with alpha = -0.75 (PyTorch default) + let alpha = -0.75; + let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset; + let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset; + + let x0 = i32(floor(fx)); + let y0 = i32(floor(fy)); + let dx = fx - f32(x0); + let dy = fy - f32(y0); + + // horizontal weights for offsets -1, 0, 1, 2 + let wx0 = cubic_weight(dx + 1.0, alpha); + let wx1 = cubic_weight(dx, alpha); + let wx2 = cubic_weight(1.0 - dx, alpha); + let wx3 = cubic_weight(2.0 - dx, alpha); + + // vertical weights for offsets -1, 0, 1, 2 + let wy0 = cubic_weight(dy + 1.0, alpha); + let wy1 = cubic_weight(dy, alpha); + let wy2 = cubic_weight(1.0 - dy, alpha); + let wy3 = cubic_weight(2.0 - dy, alpha); + + // intermediate horizontal interpolation for 4x4 grid of pixels + // x0-1, x0, x0+1, x0+2, y0-1 + let p0 = get_clamped_input(x0 - 1, y0 - 1, z_src, n_src); + let p1 = get_clamped_input(x0, y0 - 1, z_src, n_src); + let p2 = get_clamped_input(x0 + 1, y0 - 1, z_src, n_src); + let p3 = get_clamped_input(x0 + 2, y0 - 1, z_src, n_src); + let row0 = p0 * wx0 + p1 * wx1 + p2 * wx2 + p3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0 + let q0 = get_clamped_input(x0 - 1, y0, z_src, n_src); + let q1 = get_clamped_input(x0, y0, z_src, n_src); + let q2 = get_clamped_input(x0 + 1, y0, z_src, n_src); + let q3 = get_clamped_input(x0 + 2, y0, z_src, n_src); + let row1 = q0 * wx0 + q1 * wx1 + q2 * wx2 + q3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0+1 + let r0 = get_clamped_input(x0 - 1, y0 + 1, z_src, n_src); + let r1 = get_clamped_input(x0, y0 + 1, z_src, n_src); + let r2 = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src); + let r3 = get_clamped_input(x0 + 2, y0 + 1, z_src, n_src); + let row2 = r0 * wx0 + r1 * wx1 + r2 * wx2 + r3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0+2 + let s0 = get_clamped_input(x0 - 1, y0 + 2, z_src, n_src); + let s1 = get_clamped_input(x0, y0 + 2, z_src, n_src); + let s2 = get_clamped_input(x0 + 1, y0 + 2, z_src, n_src); + let s3 = get_clamped_input(x0 + 2, y0 + 2, z_src, n_src); + let row3 = s0 * wx0 + s1 * wx1 + s2 * wx2 + s3 * wx3; + + // final vertical interpolation + result = row0 * wy0 + row1 * wy1 + row2 * wy2 + row3 * wy3; + +#endif + + let dst_idx = params.offset_o + x_dst * params.so0 + y_dst * params.so1 + z_dst * params.so2 + n_dst * params.so3; + output[dst_idx] = DST_TYPE(result); +} From e10025351cc17bf52b94bacb0cd705deec947f8d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 1 May 2026 13:08:32 +0300 Subject: [PATCH 232/249] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 236ae95a80f..a03455e74c8 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -387fa29fbbf3149f06a631c7850b6c35c24b0232 +b70770970e84c30a007b3859a453768b3ece2d3d From 35cb6841299888d20ad320966ce2176c403ada7d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 1 May 2026 18:53:30 +0300 Subject: [PATCH 233/249] ggml : try fix win32 build (#0) --- ggml/src/ggml.c | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 54d3eae3e4d..81343eeb14c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -55,8 +55,13 @@ uint64_t ggml_graph_next_uid(void) { #ifdef _MSC_VER +#if defined(_WIN32) + static volatile LONG counter = 1; + return (uint64_t) InterlockedIncrement(&counter) - 1; +#else static volatile long long counter = 1; return (uint64_t) _InterlockedIncrement64(&counter) - 1; +#endif #else static uint64_t counter = 1; return __atomic_fetch_add(&counter, 1, __ATOMIC_RELAXED); From 95053f68e4c2b638b3b33c200cfd9f4dd96976b7 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 1 May 2026 15:28:32 +0200 Subject: [PATCH 234/249] vulkan: Support asymmetric FA in coopmat2 path (llama/21753) * vulkan: Support asymmetric FA in coopmat2 path There has been some recent interest/experimentation with mixed quantization types for FA. I had originally designed the cm2 FA shader with this in mind (because I didn't realize it wasn't supported at the time!), this change adds the missing pieces and enables it. Also support Q1_0 since people have been trying that out (seems crazy, but who knows). We should be able to do similar things in the coopmat1/scalar path, but there's another change open against the scalar path and I don't want to conflict. * reorder cases --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 155 +++++++++++------- .../vulkan-shaders/flash_attn_base.glsl | 6 + .../vulkan-shaders/flash_attn_cm2.comp | 94 ++++++++--- .../vulkan-shaders/vulkan-shaders-gen.cpp | 17 +- 4 files changed, 185 insertions(+), 87 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 10b73317943..c2f1883328f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -440,10 +440,12 @@ struct vk_fa_pipeline_state { bool f32acc; uint32_t flags; uint32_t limit_occupancy_shmem; + ggml_type k_type; + ggml_type v_type; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) < - std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem); + return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem, k_type, v_type) < + std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem, b.k_type, b.v_type); } }; @@ -3041,7 +3043,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device return result; } -static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { GGML_UNUSED(n_kv); GGML_UNUSED(f32acc); @@ -3055,7 +3057,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device if (small_rows) { result.block_rows = 32; result.block_cols = 32; - } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) { + } else if (ggml_is_quantized(k_type) || ggml_is_quantized(v_type) || hsk >= 256 || hsv >= 256) { result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64; result.block_cols = 32; } else { @@ -3069,7 +3071,13 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device return result; } -static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { + // Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1. + if (k_type != v_type) { + GGML_ASSERT(device->coopmat2); + return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); + } + FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; @@ -3081,7 +3089,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ if (path == FA_COOPMAT1) { bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || (!f32acc && device->coopmat_support_16x16x16_f16acc); - const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); if (!shape_ok || !shmem_ok) { @@ -3094,20 +3102,25 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ path = FA_SCALAR; } + // Q1_0 K/V is only implemented on coopmat2 (flash_attn_cm2); there is no scalar FA shader for it. + if ((k_type == GGML_TYPE_Q1_0 || v_type == GGML_TYPE_Q1_0) && device->coopmat2) { + path = FA_COOPMAT2; + } + switch (path) { case FA_SCALAR: - return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); case FA_COOPMAT1: - return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); case FA_COOPMAT2: - return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); default: throw std::runtime_error("unsupported FaCodePath"); } } static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc, - bool use_mask, bool use_mask_opt, bool use_logit_softcap) { + bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type) { const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary && (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2); @@ -3118,12 +3131,32 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size; - return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem}; + return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem, k_type, v_type}; } static std::vector get_fa_spec_constants(const vk_fa_pipeline_state& state) { - return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split, - state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem}; + const auto fa_block_bytes = [](ggml_type t) -> uint32_t { + // decodeBufF32 uses a block of vec4s for a better memory access pattern. + return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t); + }; + return { + /* 0 WorkGroupSize */ state.workgroup_size, + /* 1 Br */ state.Br, + /* 2 Bc */ state.Bc, + /* 3 HSK */ state.HSK, + /* 4 HSV */ state.HSV, + /* 5 Clamp */ static_cast(!state.aligned), + /* 6 D_split */ state.D_split, + /* 7 row_split */ state.row_split, + /* 8 SubGroupSize */ state.subgroup_size, + /* 9 SHMEM_STAGING */ state.shmem_staging ? 1u : 0u, + /*10 Flags */ state.flags, + /*11 LIMIT_OCCUPANCY_SHMEM */ state.limit_occupancy_shmem, + /*12 FaTypeK */ static_cast(state.k_type), + /*13 FaTypeV */ static_cast(state.v_type), + /*14 FaBlockBytesK */ fa_block_bytes(state.k_type), + /*15 FaBlockBytesV */ fa_block_bytes(state.v_type), + }; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -3578,16 +3611,35 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +#define CREATE_FA_CM2_MIXED() \ + for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \ + for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \ + FaCodePath path = fa.first.path; \ + uint32_t Br = fa.first.Br; \ + uint32_t Bc = fa.first.Bc; \ + bool aligned = fa.first.aligned; \ + bool f32acc = fa.first.f32acc; \ + if (path == FA_COOPMAT2) { \ + if (aligned) { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ + } \ + } else { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ + } \ + } \ + } \ + } \ + } if (device->coopmat2) { - CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) + CREATE_FA_CM2_MIXED(); } +#undef CREATE_FA_CM2_MIXED #endif #undef CREATE_FA @@ -9042,8 +9094,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(dst->type == GGML_TYPE_F32); assert(q->type == GGML_TYPE_F32); - assert(k->type == v->type); - uint32_t gqa_ratio = 1; uint32_t qk_ratio = neq2 / nek2; uint32_t workgroups_x = (uint32_t)neq1; @@ -9054,7 +9104,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). - vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc); + vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc); const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u); if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && @@ -9067,7 +9117,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_y /= gqa_ratio; } - tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc); + tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc); + + if (tuning_params.path != FA_COOPMAT2) { + GGML_ASSERT(k->type == v->type); + } const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); @@ -9106,7 +9160,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16; vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc, - mask != nullptr, use_mask_opt, logit_softcap != 0); + mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type); vk_pipeline pipeline = nullptr; @@ -15590,38 +15644,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } - // It's straightforward to support different K/V dequant, but would - // significantly increase the number of pipelines - if (op->src[1]->type != op->src[2]->type) { + // mismatching K/V type is currently supported for coopmat2 only. + if (op->src[1]->type != op->src[2]->type && !coopmat2) { return false; } - switch (op->src[1]->type) { - case GGML_TYPE_F16: - case GGML_TYPE_F32: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_IQ4_NL: - // supported in scalar and coopmat2 paths - break; - // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently - //case GGML_TYPE_Q2_K: - //case GGML_TYPE_Q3_K: - //case GGML_TYPE_Q4_K: - //case GGML_TYPE_Q5_K: - //case GGML_TYPE_Q6_K: - //case GGML_TYPE_IQ1_S: - //case GGML_TYPE_IQ1_M: - //case GGML_TYPE_IQ2_XXS: - //case GGML_TYPE_IQ2_XS: - //case GGML_TYPE_IQ2_S: - //case GGML_TYPE_IQ3_XXS: - //case GGML_TYPE_IQ3_S: - //case GGML_TYPE_IQ4_XS: - - default: + auto fa_kv_ok = [coopmat2](ggml_type t) { + switch (t) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_0: + return true; + case GGML_TYPE_Q1_0: + return coopmat2; + default: + return false; + } + }; + if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) { return false; } if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 6f349246915..efed3a73e22 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -13,6 +13,12 @@ layout (constant_id = 8) const uint32_t SubGroupSize = 32; layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0; layout (constant_id = 10) const uint32_t Flags = 0; layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0; +// ggml_type enumerant for K/V +layout (constant_id = 12) const uint32_t FaTypeK = 0; +layout (constant_id = 13) const uint32_t FaTypeV = 0; +// sizeof(decode buffer): quants -> ggml block size; F32 -> 16 (decodeBufF32 vec4). +layout (constant_id = 14) const uint32_t FaBlockBytesK = 2; +layout (constant_id = 15) const uint32_t FaBlockBytesV = 2; const bool USE_MASK_OPT = (Flags & 1) != 0; const bool MASK_ENABLE = (Flags & 2) != 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 0ea181342ce..8a7bbaeb92c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -17,8 +17,57 @@ #extension GL_EXT_null_initializer : enable #include "types.glsl" -#include "dequant_funcs_cm2.glsl" #include "flash_attn_base.glsl" +#include "dequant_funcs_cm2.glsl" + +// buffer_reference stride = sizeof(struct) = FaBlockBytesK/V. +layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K { + uint8_t raw[FaBlockBytesK]; +}; +layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_V { + uint8_t raw[FaBlockBytesV]; +}; + +uint fa_block_elems(uint ty) { + switch (ty) { + case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32) + case 1u: return 1u; // GGML_TYPE_F16 + case 2u: return uint(QUANT_K_Q4_0); + case 3u: return uint(QUANT_K_Q4_1); + case 6u: return uint(QUANT_K_Q5_0); + case 7u: return uint(QUANT_K_Q5_1); + case 8u: return uint(QUANT_K_Q8_0); + case 41u: return uint(QUANT_K_Q1_0); + default: + return 1u; + } +} + +float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeK) { + case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); + case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return float16_t(0); + } +} + +float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeV) { + case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); + case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return float16_t(0); + } +} layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; @@ -55,12 +104,6 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele return max(elem0, elem1); } -#if BLOCK_SIZE > 1 -#define DECODEFUNC , DEQUANTFUNC -#else -#define DECODEFUNC -#endif - // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) @@ -95,10 +138,6 @@ ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c } void main() { -#ifdef NEEDS_INIT_IQ_SHMEM - init_iq_shmem(gl_WorkGroupSize); -#endif - init_indices(); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); @@ -107,10 +146,10 @@ void main() { tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); -#if BLOCK_SIZE > 1 - tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); - tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); -#endif + const uint bs_k = fa_block_elems(FaTypeK); + const uint bs_v = fa_block_elems(FaTypeV); + tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, bs_k); + tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, bs_v); tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK); tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK); @@ -120,10 +159,12 @@ void main() { if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { q_stride &= ~7; -#if BLOCK_SIZE == 1 - k_stride &= ~7; - v_stride &= ~7; -#endif + if (bs_k == 1u) { + k_stride &= ~7; + } + if (bs_v == 1u) { + v_stride &= ~7; + } m_stride &= ~7; } tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); @@ -230,7 +271,13 @@ void main() { coopmat K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); + // F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128. + const bool k_use_decode = (bs_k > 1u); + if (k_use_decode) { + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK); + } else { + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); + } S = coopMatMulAdd(Qf16, K_T, S); if (LOGIT_SOFTCAP) { @@ -291,7 +338,12 @@ void main() { coopmat V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC); + const bool v_use_decode = (bs_v > 1u); + if (v_use_decode) { + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV); + } else { + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); + } L = eM*L + rowsum; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ff836615330..6f2a929c40c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -641,20 +641,17 @@ void process_shaders() { fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; } + if (fp16) { +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); +#endif + } + for (const auto& tname : type_names) { if (tname == "bf16") continue; if (fp16) { -#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); - } else { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc); - } -#endif #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", From 9623c1203b91da1467c1f3692c713a1f09dfa8c4 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Fri, 1 May 2026 23:55:01 +0900 Subject: [PATCH 235/249] ggml-webgpu: Fix vectorized handling in mul-mat and mul-mat-id (llama/22578) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix vectorized condition of mul-mat-fast pipeline and add vectorized variant to mul-mat-id * Apply suggestion from @CISC Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Sigbjørn Skjæret --- ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 651c9cbcdf6..cff93b8d170 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1779,12 +1779,12 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_mul_mat_pipeline_key key = {}; - key.src0_type = context.src0->type; - key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); @@ -2143,6 +2143,9 @@ class ggml_webgpu_shader_lib { // variant suffix for src1 type variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + if (key.vectorized) { + variant += "_vectorized"; + } auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines); From f2ce24fa5c946d7cbc42f52947428c2eba299393 Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Fri, 1 May 2026 22:39:23 +0530 Subject: [PATCH 236/249] hexagon: enable non-contiguous row tensor support for unary ops (llama/22574) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 4 +- ggml/src/ggml-hexagon/htp/hvx-exp.h | 4 +- ggml/src/ggml-hexagon/htp/unary-ops.c | 110 ++++++++++++++++++------- 3 files changed, 85 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 17ac083f4ea..6bb073102c0 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2421,8 +2421,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses return false; } - // TODO: add support for non-contigiuos tensors - if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + // TODO: add support for non-contiguous elements within a row + if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.h b/ggml/src/ggml-hexagon/htp/hvx-exp.h index 84e4836dc92..e71ec4909a6 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-exp.h +++ b/ggml/src/ggml-hexagon/htp/hvx-exp.h @@ -17,7 +17,7 @@ #define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 #define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 #define EXP_ONE (0x3f800000) // 1.0 -#define EXP_RANGE_R (0x42B16666) // 88.7 +#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228 #define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN)) static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { @@ -163,7 +163,7 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict HVX_Vector vec_out = Q6_V_vzero(); static const float kInf = INFINITY; - static const float kMaxExp = 88.7f; + static const float kMaxExp = 88.7228f; const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); const HVX_Vector inf = hvx_vec_splat_f32(kInf); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 03eccfd55e3..819cdc49bd9 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -26,8 +26,8 @@ struct htp_unary_context { const uint8_t * data_src0; uint8_t * data_dst; - size_t src0_row_size; - size_t dst_row_size; + size_t src0_data_row_size; // actual data bytes per row + size_t dst_data_row_size; // actual data bytes per row size_t src0_row_size_aligned; size_t dst_row_size_aligned; @@ -41,6 +41,40 @@ struct htp_unary_context { uint32_t nc; }; +// Convert flat row index to DDR byte offset using the tensor's actual strides. +// ir = i1 + ne1*(i2 + ne2*i3) => offset = i1*nb1 + i2*nb2 + i3*nb3 +static inline size_t unary_row_offset(uint32_t ir, + uint32_t ne1, uint32_t ne2, + size_t nb1, size_t nb2, size_t nb3) { + const uint32_t i1 = ir % ne1; + const uint32_t i2 = (ir / ne1) % ne2; + const uint32_t i3 = ir / (ne1 * ne2); + return i1 * nb1 + i2 * nb2 + i3 * nb3; +} +// Safe DMA block size from row `ir`: clamp to the tighter dim-1 slice +// boundary of src and dst so the nb1 stride stays valid for all rows. +static inline uint32_t unary_block_size(uint32_t ir, + uint32_t end_row, + uint32_t block, + bool src_contig, + bool dst_contig, + uint32_t src_ne1, + uint32_t dst_ne1) { + uint32_t limit = MIN(block, end_row - ir); + + if (!src_contig) { + const uint32_t src_slice_end = (ir / src_ne1 + 1) * src_ne1; + limit = MIN(limit, src_slice_end - ir); + } + + if (!dst_contig) { + const uint32_t dst_slice_end = (ir / dst_ne1 + 1) * dst_ne1; + limit = MIN(limit, dst_slice_end - ir); + } + + return limit; +} + #define htp_unary_preamble \ const uint32_t ne00 = src->ne[0]; \ const uint32_t ne01 = src->ne[1]; \ @@ -276,8 +310,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * int32_t * op_params = octx->op_params; uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread; - const size_t src0_row_size = uctx->src0_row_size; - const size_t dst_row_size = uctx->dst_row_size; + const size_t src0_data_row_size = uctx->src0_data_row_size; + const size_t dst_data_row_size = uctx->dst_data_row_size; const size_t src0_row_size_aligned = uctx->src0_row_size_aligned; const size_t dst_row_size_aligned = uctx->dst_row_size_aligned; @@ -303,7 +337,16 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * size_t src0_spad_half_size = uctx->src0_spad_half_size; size_t dst_spad_half_size = uctx->dst_spad_half_size; - const int BLOCK = uctx->block; + // Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride + // 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every + // transfer stays within a nb1-uniform region. Skipped for contiguous tensors. + const bool src0_contig = (nb02 == (size_t)ne01 * nb01) && + (nb03 == (size_t)ne02 * nb02); + const bool dst_contig = (nb2 == (size_t)ne1 * nb1) && + (nb3 == (size_t)ne2 * nb2); + const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01); + const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1); + const uint32_t BLOCK = MIN(src0_max_block, dst_max_block); if (BLOCK == 0) { FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", octx->src0_spad.size_per_thread, src0_row_size_aligned); @@ -312,21 +355,23 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * dma_queue * dma_queue = octx->ctx->dma[ith]; - for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { - const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) { + const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) - dma_queue_push_vtcm_to_ddr(dma_queue, + dma_queue_push(dma_queue, dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), - dst_row_size, dst_row_size_aligned, 0); + nb1, dst_row_size_aligned, dst_data_row_size, 0); - dma_queue_push_ddr_to_vtcm(dma_queue, - dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)), - src0_row_size_aligned, src0_row_size, block_size); + const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03); + dma_queue_push(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off), + src0_row_size_aligned, nb01, src0_data_row_size, block_size); + ir += block_size; } - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { - const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + for (uint32_t ir = src0_start_row; ir < src0_end_row; ) { + const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); float * dst_spad = (float *) dma_queue_pop(dma_queue).src; float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; @@ -361,18 +406,25 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * break; } - dma_queue_push_vtcm_to_ddr(dma_queue, - dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), - dst_row_size, dst_row_size_aligned, block_size); + const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3); + dma_queue_push(dma_queue, + dma_make_ptr(data_dst + dst_off, dst_spad), + nb1, dst_row_size_aligned, dst_data_row_size, block_size); // prefetch N+2 loop iteration if any - const uint32_t pref_block = (ir + BLOCK * 2); - if (pref_block < src0_end_row) { - const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); - dma_queue_push_ddr_to_vtcm(dma_queue, - dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)), - src0_row_size_aligned, src0_row_size, pref_block_size); + const uint32_t next_ir = ir + block_size; + if (next_ir < src0_end_row) { + const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); + const uint32_t pref_ir = next_ir + next_block_size; + if (pref_ir < src0_end_row) { + const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); + const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); + dma_queue_push(dma_queue, + dma_make_ptr(src0_spad, data_src + src0_pref_off), + src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); + } } + ir += block_size; } dma_queue_flush(dma_queue); @@ -426,11 +478,11 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); - const size_t src0_row_size = src0->nb[1]; - const size_t dst_row_size = dst->nb[1]; + const size_t src0_data_row_size = src0->ne[0] * sizeof(float); + const size_t dst_data_row_size = dst->ne[0] * sizeof(float); - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN); // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size @@ -468,8 +520,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { .data_src0 = (const uint8_t *)src0->data, .data_dst = (uint8_t *)dst->data, - .src0_row_size = src0_row_size, - .dst_row_size = dst_row_size, + .src0_data_row_size = src0_data_row_size, + .dst_data_row_size = dst_data_row_size, .src0_row_size_aligned = src0_row_size_aligned, .dst_row_size_aligned = dst_row_size_aligned, From 4861a3eeb5cb86df2de29c38c488e44d8dc9f6ca Mon Sep 17 00:00:00 2001 From: Yiwei Shao <44545837+njsyw1997@users.noreply.github.com> Date: Fri, 1 May 2026 20:29:13 -0700 Subject: [PATCH 237/249] hexagon: hmx flash attention (llama/22347) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * hmx: extract shared interleave headers and unify matmul batched * hmx: add HMX-accelerated flash attention for prefill * hmx: replace asm wrappers with Q6_ intrinsics in hmx-utils.h Switches three single-instruction helpers from inline asm to the matching Q6_ intrinsics, matching the style established by aizip f8737609a and used by the upstream PR #21554 hmx-matmul-ops.c rewrite: hmx_set_output_scales asm "bias=mxmem2" -> Q6_bias_mxmem2_A hmx_load_tile_pair_fp16 asm packet -> Q6_activation_hf_mxmem_RR + Q6_weight_hf_mxmem_RR hmx_consume_accumulator_fp16 asm "mxmem=acc" -> Q6_mxmem_AR_after_hf hmx_load_tiles_fp16 stays on inline asm: it uses ":deep" activation streaming, and the mixed Q6_activation_hf_mxmem_RR_deep + non-deep Q6_weight_hf_mxmem_RR pair fails the HMX backend constraint check ("activate weight pair (1) exceeds limit (1)"). The asm bundle keeps both halves in one VLIW packet and avoids the diagnostic. Functionally equivalent — same instructions emitted; the Q6_ intrinsics just give the compiler more visibility for scheduling. * hmx: drop the duplicate interleave_fp16_weight_chunk_to_tiles * hmx: apply upstream optimization to hmx-flash-attn-ops.c apply restrict, __builtin_assume, and pointer accumulation to the three HMX workers (qk_dot, o_update, o_norm) and the matching inline HMX loops in op_hmx_flash_attn_ext. * hmx: unify interleave helper * hmx: multi-thread Q load / O store and enable prefill FA dispatch Extract inline Q-load and O-store loops into worker_pool-parallel helpers (fa_phase_q_load, fa_phase_o_store) so HVX threads split the F32↔F16 conversion work across row ranges. Also relax the softmax threading gate from n_row_vec_cnt >= n_threads to >= 2, which was unnecessarily forcing single-thread fallback when n_rows_g < 512. On the dispatch side, remove the ne[2] != 1 guard that blocked multi-head (prefill) FA from reaching the HTP backend — GQA is already handled internally by both the HMX and HVX flash-attention paths. * hmx: relax matmul pipeline gate to cover k > n shapes (e.g. FFN_down) * hmx: optimize FA softmax mask phase (no-ALiBi fast path + GQA dedup) * hmx: Add an asm memory clobber at the phase boundary to prevent reorder bug * [experimental]: fp16 softmax (EXP2_HF) to accelerate fa Bake log2(e) into qk_scale and use hvx_exp2_hf directly for P and m_diff (base-2 consistent, matches htp-ops-lib). ~22 ALU ops for 64 lanes vs ~44 for the F32 round-trip path. * hmx flash-attn: refine cost model coefficients based on profiling data * hmx flash-attn: replace asm clobber with targeted volatile reads on vtcm_d_tiles * hmx flash-attn: fix prefill correctness (dst indexing, softmax reduce, V stride) * hmx flash-attn: fix p_tiles dual-tile OOB race; enable MT + pipeline * hmx flash-attn: preserve additive mask bias in no-ALiBi fast path The no-ALiBi fast path (max_bias==0) was skipping mask add entirely on the assumption that mask values are only {0, -inf}. This is wrong when the mask carries additive positional bias — those terms were silently dropped. Keep the slope-mul skip (slope≡1.0) but add mask back so the bias survives; vmux still clamps below -16 to -inf. Also add HMX FA coverage to test-backend-ops: prefill shapes (nb=64, nb=32) × {mask on/off} × {ALiBi on/off} × {softcap on/off}, F16 KV, hs ∈ {64, 128}. * hmx: fix softcap+EXP2_HF interaction, tighten matmul pipeline gate, add FA tests - flash-attn: when EXP2_HF is on AND logit_softcap is active, fold log2(e) into the post-tanh multiplier (v_cap) instead of pre-baking it into qk_scale. Pre-baking shifted the tanh knee from x≈c to x≈c/log2(e) and produced numerically wrong softcapped outputs whenever both knobs were enabled. - flash-attn softmax (fa_softmax_thread): replace the union+memcpy scalar extract pattern with HVX vmux-based per-row accumulators on rowmax/rowsum. Add hvx_vec_get_f16 helper in hvx-base.h. Functional parity, less scalar code, clearer hf/qf16 lane-format contract. - matmul (hmx_mat_mul_permuted_qk_0_d16a32): pick pipeline vs sequential layout based on whether the chunker actually yields >=2 n-chunks, instead of the static (m>=128 && n>=256) gate. Avoids paying for output double-buffer + worker dispatch when there is no HMX/HVX overlap to gain (e.g. shapes that collapse to one n-chunk). - tests: add HMX flash-attention coverage over the {mask, ALiBi (max_bias), logit_softcap} cross-product for the prefill path — head_dim 64/128, GQA 4×4, kv=512/nb=64 plus a kv=113/nb=32 non-aligned case. * [Help Wanted]: refactor D matrix computation into separate function for clarity and maintainability * format code * hexagon: looks like -O3 is causing issues with the large code base, switch to -O2 and -flto instead * hexagon: use hex_ prefix for swap_ptr * hexagon: move vtcm_seq_alloc into vtcm-utils.h More vtcm allocator updates are coming so it makes sense to start the separate hdr for it. * hmx-utils: add hmx_prefix for layout converters * hmx-mm: move main hmx_mm functions to the end, remove unused fwd decls, etc * hmx-mm: remove unused qweight_fetch_task_state_t and minor alignment fixes * hmx-fa: minor alignment fixes * hmx-fa: move hmx_flash_atten into hmx-ops.h * hmx-fa: remove redundant workpool pointer in the hmx_fa_ctx, plus minor alignment updates * hmx-fa: minor alignment and simplifications * hexagon: move FA_EXP_F16 option to hostside CMake file * hmx-fa: use hvx_vec_splat_f16 instead of fp16_to_bits * hmx-fa: add hvx_splat_u16/u8 and use that in the fa instead custom hvx_fill * hmx-fa: some more alignment updates in the core fa function * hmx-fa: keep slopes in vtcm in fp16 Saves malloc/free and removes the need for float -> fp16 downcast on every use. * hexagon: consistent noinline usage (after static) * hex-hmx: consistent use FARF_HIGH to enable debug output * hmx-utils: no need for always_inline attr * hex-hmx: consistent noinline usage (static noinline ...) * hex-hmx: simplify init_col_scales * hexagon: fix editorconfig errors * hmx-mm: minor alignment fixes --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/CMakeLists.txt | 3 +- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 3 +- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 7 + .../ggml-hexagon/htp/cmake-toolchain.cmake | 10 +- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 14 +- ggml/src/ggml-hexagon/htp/hex-utils.h | 6 + .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 1840 +++++++++++++++++ ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 1435 +++++++------ ggml/src/ggml-hexagon/htp/hmx-ops.h | 3 + ggml/src/ggml-hexagon/htp/hmx-utils.h | 192 +- ggml/src/ggml-hexagon/htp/hvx-base.h | 6 + ggml/src/ggml-hexagon/htp/hvx-copy.h | 37 +- ggml/src/ggml-hexagon/htp/vtcm-utils.h | 16 + 13 files changed, 2798 insertions(+), 774 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/vtcm-utils.h diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index f3a583543c6..b82bae0c103 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -22,7 +22,8 @@ message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include(ExternalProject) -option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) +option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) +option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF) set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate") set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)") diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 6bb073102c0..df4ed101464 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2254,8 +2254,7 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess return false; } - if (dst->ne[2] != 1 || dst->ne[3] != 1) { - // FA during prompt still needs work + if (dst->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 8bd528478ba..7c9e4cda5f1 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -44,6 +44,11 @@ target_compile_definitions(${HTP_LIB} PRIVATE $,FARF_HIGH=1,> FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) +if (GGML_HEXAGON_FA_EXP2_HF) + message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") + target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) +endif() + # HMX acceleration: available on v73+ architectures set(HTP_HMX_VERSIONS v73 v75 v79 v81) list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) @@ -52,11 +57,13 @@ if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE hmx-queue.c hmx-matmul-ops.c + hmx-flash-attn-ops.c ) # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) set_source_files_properties( hmx-matmul-ops.c + hmx-flash-attn-ops.c PROPERTIES COMPILE_OPTIONS "-mhmx" ) diff --git a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake index 7fa236e328f..ed5c198468c 100644 --- a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +++ b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake @@ -138,15 +138,15 @@ set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,") set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,") #Compiler Options -set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") +set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") -set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") -set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3") +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") +set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O2") set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") -set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") -set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3") +set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") +set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O2") set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}") set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}") diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index d296a322589..d95df6ac9d5 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -17,13 +17,14 @@ #include "htp-ctx.h" #include "htp-ops.h" #include "htp-ops.h" +#include "hmx-ops.h" // Must be multiple of 32 #define FLASH_ATTN_BLOCK_SIZE (32 * 2) // This is a bit of a hack because the compiler is strugling to properly inline // the default hvx_vec_f32_to_f16 with output into the local array. -static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) +static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) { *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1); } @@ -621,6 +622,17 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } +#ifdef HTP_HAS_HMX + // HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV + if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) { + int ret = hmx_flash_attn_ext(octx); + if (ret == HTP_STATUS_OK) { + return ret; + } + // VTCM too small or other failure -> fall through to HVX path + } +#endif + struct htp_fa_context factx; factx.octx = octx; diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index 329249e11da..6239ceff4b4 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -74,6 +74,12 @@ static inline size_t hex_smax(size_t a, size_t b) { return a > b ? a : b; } +static inline void hex_swap_ptr(void ** p1, void ** p2) { + void * t = *p1; + *p1 = *p2; + *p2 = t; +} + static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); Q6_l2fetch_AP((void *) p, control); diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c new file mode 100644 index 00000000000..8a6d7c14edf --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -0,0 +1,1840 @@ +// HMX-accelerated Flash Attention for prefill (neq1 >= 32). +// Ported from htp-ops-lib/src/dsp/ops/flash_attn.c, adapted to the htp/ codebase. + +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "hex-dma.h" +#include "hmx-profile.h" +#include "hmx-queue.h" +#include "hmx-utils.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-dump.h" +#include "hvx-reduce.h" +#include "hvx-utils.h" +#include "vtcm-utils.h" +#include "worker-pool.h" + +// ============================================================================ +// Constants +// ============================================================================ + +// Tile constants from hmx-utils.h +// HMX_FP16_TILE_N_ROWS = 32 +// HMX_FP16_TILE_N_COLS = 32 +// HMX_FP16_TILE_N_ELMS = 1024 +// HMX_FP16_TILE_SIZE = 2048 + +// ============================================================================ +// Dynamic block size computation (GQA-aware) +// ============================================================================ + +// Exact VTCM usage for a given (gqa_factor, DK, DV, Br, Bc) configuration. +// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. +// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales +// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads) { + const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); + const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] + const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong + const size_t k_dma_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K DMA: [Bc, DK] x2 double-buf + const size_t v_dma_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V DMA: [Bc, DV] x2 double-buf + const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved + const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved + const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc] + const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); // D: [g_br, g_br] + const size_t col_vec_size = hex_align_up(g_br * sizeof(__fp16), 256); // m, l, etc. + const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256); + const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128); + const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096); + const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128); + + return q_tile_size * 1 // Q tiles + + o_tile_size * 2 // O ping-pong + + k_dma_size * 2 // K DMA x2 + + v_dma_size * 2 // V DMA x2 + + k_tile_size * 1 // K tiles + + v_tile_size * 1 // V tiles + + s_tile_size * 2 // S + P + + d_tile_size * 1 // D (diagonal matrix) + + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum + + row_vec_size * 2 * n_threads // per-thread softmax row scratch + + m_buf_size * 1 // mask VTCM buffer [Br rows] + + slopes_size // Slopes + + 256 * 2; // HMX scales (id + qk) +} + +// ============================================================================ +// FP16 exp2 polynomial (ported from htp-ops-lib/include/dsp/hvx_math.h) +// ============================================================================ +// 5th-order Horner polynomial for exp2(x) in qf16/hf16 domain. Input must be +// ≤ 0 (safe softmax invariant — overflow handling omitted). ~18 ALU ops per +// 64 fp16 lanes, fully parallel across HVX threads (no scatter/gather engine). +// Replaces the F32 round-trip (qf16→f32→exp→f32→f16, ~44 ops for 2×32 lanes). +static inline HVX_Vector hvx_exp2_hf(HVX_Vector x_v) { + const HVX_Vector zero_v = Q6_V_vzero(); + const HVX_Vector half_hf_v = Q6_Vh_vsplat_R(0x3800); // fp16 0.5 + + // k = round_toward_neg_inf(x); f = (float)k; frac = x - f + HVX_Vector x_minus_half = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, half_hf_v)); + HVX_Vector k_v = Q6_Vh_equals_Vhf(x_minus_half); // truncate to int16 + HVX_Vector f_v = Q6_Vhf_equals_Vh(k_v); // back to fp16 + + HVX_Vector x_qf16 = Q6_Vqf16_vsub_VhfVhf(x_v, f_v); // fractional part in qf16 + + // Horner: y = ((((E5*x + E4)*x + E3)*x + E2)*x + E1)*x + E0 + HVX_Vector y = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), x_qf16); // E5*x + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x157d)); // + E4 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x20ed)); // + E3 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x2b1b)); // + E2 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x33b0)); // + E1 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x398c)); // + E0 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); // y = y * x + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x3c00)); // + 1.0 + + // Combine polynomial (mantissa) with integer part (exponent): result = y * 2^k + y = Q6_Vhf_equals_Vqf16(y); + HVX_Vector y_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(y, 1), 11); + y_exp = Q6_Vh_vadd_VhVh(k_v, y_exp); + HVX_VectorPred q_underflow = Q6_Q_vcmp_gt_VhVh(zero_v, y_exp); + y = Q6_Vh_vaslacc_VhVhR(y, k_v, 10); + return Q6_V_vmux_QVV(q_underflow, zero_v, y); +} + +#define FA_MIN_KV_BLOCKS 3 + +// Cost-based (Br, Bc) search for flash attention with pipeline constraint. +// +// VTCM model (same as before): +// overhead + g_br * per_gbr + g_br² * per_gbr2 + Bc * per_bc + g_br * Bc * per_gbr_bc +// +// Cost model (minimization objective): +// Q * (c_q_fixed + K * c_iter_fixed), where Q = ceil(qo/Br), K = ceil(kv/Bc) +static int hmx_fa_find_chunk_size(size_t * Br_out, + size_t * Bc_out, + size_t gqa_factor, + size_t DK, + size_t DV, + size_t qo_len, + size_t kv_len, + size_t vtcm_budget, + size_t n_threads) { + const size_t T = HMX_FP16_TILE_N_ROWS; // 32 + const size_t br_unit = hmx_ceil_div(T, gqa_factor); + // Bc must be a multiple of 64 so that n_tiles_per_bc is even. The softmax + // P-tile write uses a dual-tile pattern (vshuff + two stores 16 slots apart) + // that would race across r0 blocks if the last dual-tile is half-occupied. + // See .cursor/todos/hmx-flash-attn-bc-search-space.md for the perf trade-off. + const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64 + const size_t fp16 = sizeof(__fp16); + + // Approximate per-unit VTCM costs (without per-buffer alignment padding). + const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16; // Q + O×2 + 4 col vectors + const size_t per_gbr2 = fp16; // D diagonal matrix + const size_t per_bc = + 3 * (DK + DV) * fp16 + 2 * n_threads * fp16; // K_dma×2 + V_dma×2 + K_tile + V_tile + row bufs + const size_t per_gbr_bc = 2 * fp16; // S + P + + const size_t overhead = 256 * 2 + 13 * 4096; + + if (vtcm_budget <= overhead) { + return -1; + } + const size_t usable = vtcm_budget - overhead; + + // Br_max: largest Br aligned to br_unit that does not exceed qo_len. + const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit; + + // Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS. + // Only relax when kv_len is too short to form enough blocks. + const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); + const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) : + (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit); + // Cost coefficients calibrated from profiling + const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store + const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers + + size_t best_cost = SIZE_MAX, best_mn = 0; + size_t best_Br = 0, best_Bc = 0; + + for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) { + const size_t g_br = hex_align_up(gqa_factor * Br, T); + + // g_br-dependent VTCM cost: g_br * per_gbr + g_br² * per_gbr2 + const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2; + if (gbr_cost >= usable) { + if (Br == br_unit) { + break; + } + continue; + } + + // Analytically solve for max Bc: + // remain >= Bc * (per_bc + g_br * per_gbr_bc + Br * fp16_mask) + // The Br * fp16 term accounts for the VTCM mask buffer [Br × Bc]. + const size_t remain = usable - gbr_cost; + const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16; + size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit); + if (Bc < bc_unit) { + if (Br == br_unit) { + break; + } + continue; + } + + // Exact VTCM verification (alignment padding may push over budget) + while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads) > vtcm_budget) { + Bc -= bc_unit; + } + if (Bc < bc_unit) { + if (Br == br_unit) { + break; + } + continue; + } + + const size_t q_blocks = (qo_len + Br - 1) / Br; + const size_t kv_blocks = (kv_len + Bc - 1) / Bc; + const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed); + const size_t mn = Br * Bc; + + if (cost < best_cost || (cost == best_cost && mn > best_mn)) { + best_cost = cost; + best_mn = mn; + best_Br = Br; + best_Bc = Bc; + } + + if (Br == br_unit) { + break; + } + } + + if (best_Br == 0) { + return -1; + } + + *Br_out = best_Br; + *Bc_out = best_Bc; + return 0; +} + +// ============================================================================ +// Tile interleave / extract helpers +// ============================================================================ + +// transpose scatter offsets moved to hmx-utils.h as hmx_transpose_scatter_offsets + +// Scatter offsets for diagonal tile: entry[2i] = i*136, entry[2i+1] = i*136+6 +// 136 = 4 * 32 + 8 = byte offset to diagonal in a 32x32 fp16 interleaved tile +static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = { + 0 * 136, 0 * 136 + 6, + 1 * 136, 1 * 136 + 6, + 2 * 136, 2 * 136 + 6, + 3 * 136, 3 * 136 + 6, + 4 * 136, 4 * 136 + 6, + 5 * 136, 5 * 136 + 6, + 6 * 136, 6 * 136 + 6, + 7 * 136, 7 * 136 + 6, + 8 * 136, 8 * 136 + 6, + 9 * 136, 9 * 136 + 6, + 10 * 136, 10 * 136 + 6, + 11 * 136, 11 * 136 + 6, + 12 * 136, 12 * 136 + 6, + 13 * 136, 13 * 136 + 6, + 14 * 136, 14 * 136 + 6, + 15 * 136, 15 * 136 + 6, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, +}; + +// hmx_interleave_rows_to_tiles and hmx_interleave_cols_to_tiles are in hmx-utils.h + +// ============================================================================ +// HMX Flash Attention context (GQA-merged) +// ============================================================================ + +struct hmx_fa_context { + const struct htp_ops_context * octx; + bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2 + uint32_t n_threads; + + // Op parameters + float scale; + float max_bias; + float logit_softcap; + uint32_t n_head_log2; + float m0, m1; + + // Dimensions + uint32_t DK, DV; + uint32_t n_kv; // kv_len + uint32_t n_kv_heads; // number of KV heads + uint32_t n_heads; // number of Q heads + uint32_t G; // GQA factor = n_heads / n_kv_heads + uint32_t n_kv_blocks; + uint32_t neq1; // Q token count + + // Types + bool is_q_fp32; + bool is_dst_fp32; + + // Dynamic block sizes + uint32_t Br; // Q tokens per block (before GQA expansion) + uint32_t Bc; + uint32_t g_br; // hex_align_up(G * Br, 32) - actual tile row dim + + // VTCM buffers (allocated by vtcm_seq_alloc) + __fp16 * vtcm_q_tiles; // Q tile format [g_br, D] + __fp16 * vtcm_o_tiles[2]; // O ping-pong [g_br, D] + __fp16 * vtcm_k_fp16[2]; // K DMA double-buffer [Bc, D] + __fp16 * vtcm_v_fp16[2]; // V DMA double-buffer [Bc, D] + __fp16 * vtcm_k_tiles; // K tiles (transposed) + __fp16 * vtcm_v_tiles; // V tiles (column-major) + __fp16 * vtcm_s_tiles; // S = QK^T [g_br, Bc] + __fp16 * vtcm_p_tiles; // P = softmax(S) [g_br, Bc] + __fp16 * vtcm_d_tiles; // Diagonal rescale [g_br, g_br] + HVX_Vector * vtcm_m_vec; // Row max [g_br] + HVX_Vector * vtcm_l_vec; // Row sum [g_br] + HVX_Vector * vtcm_s_rowmax; // Softmax intermediate [g_br] + HVX_Vector * vtcm_p_rowsum; // Softmax intermediate [g_br] + HVX_Vector * vtcm_row_bufs; // Per-thread softmax row scratch [n_threads][2][Bc/64] + uint8_t * vtcm_hmx_scales_id; // HMX output scales (identity) + uint8_t * vtcm_hmx_scales_qk; // HMX output scales (qk_scale) + __fp16 * vtcm_mask_buf; // VTCM mask buffer [Br × m_line], DMA'd per KV block + __fp16 * vtcm_slopes; // ALiBi slopes [g_br] + size_t row_buf_stride; // HVX vectors per row buffer (Bc/64) + size_t mask_buf_row_stride; // elements (__fp16) per row in mask buffer + bool mask_broadcast; // true when mask->ne[2] == 1 (head-independent, single 2D DMA) +}; + +// ============================================================================ +// Multi-thread K interleave phase +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + int kv_rows; + size_t src_stride; + size_t buf_idx; +} fa_k_int_args_t; + +static void fa_k_interleave_thread(unsigned int n, unsigned int i, void * data) { + fa_k_int_args_t * args = (fa_k_int_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const int total_rows = args->kv_rows; + const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2); // ensure even (row pairs) + const int start = i * rows_per_t; + const int end = hex_smin(start + rows_per_t, total_rows); + + if (start >= total_rows) { + return; + } + + hmx_interleave_rows_to_tiles(factx->vtcm_k_tiles, factx->vtcm_k_fp16[args->buf_idx], total_rows, (int) factx->DK, + (int) args->src_stride, start, end); +} + +static void fa_phase_k_interleave(struct hmx_fa_context * factx, int kv_rows, size_t src_stride, size_t buf_idx) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_k_int_args_t args = { factx, kv_rows, src_stride, buf_idx }; + if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_k_interleave_thread, &args, factx->n_threads); + } else { + fa_k_interleave_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread V interleave phase +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + int kv_rows; + size_t src_stride; + size_t buf_idx; + size_t n_col_tiles; +} fa_v_int_args_t; + +static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) { + fa_v_int_args_t * args = (fa_v_int_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const int total_rows = args->kv_rows; + const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2); + const int start = i * rows_per_t; + const int end = hex_smin(start + rows_per_t, total_rows); + + if (start >= total_rows) { + return; + } + + hmx_interleave_cols_to_tiles(factx->vtcm_v_tiles, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, + (int) args->src_stride, (int) args->n_col_tiles, start, end); +} + +static void fa_phase_v_interleave(struct hmx_fa_context * factx, + int kv_rows, + size_t src_stride, + size_t buf_idx, + size_t n_col_tiles) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_v_int_args_t args = { factx, kv_rows, src_stride, buf_idx, n_col_tiles }; + if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_v_interleave_thread, &args, factx->n_threads); + } else { + fa_v_interleave_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread Q load phase: read Q[G × neq1, DK] from DDR, convert F32→F16 +// (or deal F16 pairs), and write interleaved into vtcm_q_tiles. +// Each thread owns a disjoint range of row pairs; writes target distinct tile +// slots (r0 selects tile row, r1 selects intra-tile slot), so there is no +// write conflict. Padding fill (when n_rows_g < g_br) is done single-threaded +// by the caller before dispatching. +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + const struct htp_tensor * q; + uint32_t q_start; + uint32_t kv_head; + uint32_t ib3; + size_t n_rows_g; +} fa_q_load_args_t; + +static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) { + fa_q_load_args_t * args = (fa_q_load_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t G = factx->G; + const size_t DK = factx->DK; + + // Partition row pairs across threads. Keep each thread's start even so r/r+1 + // are always in the same thread's range. + const size_t rows_per_t = hex_align_up(hmx_ceil_div(n_rows_g, n), 2); + const size_t start = (size_t) i * rows_per_t; + const size_t end = hex_smin(start + rows_per_t, n_rows_g); + + if (start >= n_rows_g) { + return; + } + + const struct htp_tensor * q = args->q; + const uint32_t q_start = args->q_start; + const uint32_t kv_head = args->kv_head; + const uint32_t ib3 = args->ib3; + + for (size_t r = start; r < end; r += 2) { + const bool next_row_valid = (r + 1) < n_rows_g; + + const size_t q_idx0 = (r + 0) / G; + const size_t h_idx0 = (r + 0) % G; + const size_t q_idx1 = (r + 1) / G; + const size_t h_idx1 = (r + 1) % G; + + const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] + + (kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3]; + const uint8_t * q_ptr1 = next_row_valid ? ((const uint8_t *) q->data + (q_start + q_idx1) * q->nb[1] + + (kv_head * G + h_idx1) * q->nb[2] + ib3 * q->nb[3]) : + NULL; + + size_t r0 = r / HMX_FP16_TILE_N_ROWS; + size_t r1 = r % HMX_FP16_TILE_N_ROWS; + __fp16 * out_base = factx->vtcm_q_tiles + r0 * HMX_FP16_TILE_N_ROWS * DK; + + if (factx->is_q_fp32) { + const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0; + const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL; + + for (uint32_t d = 0; d < DK / 32; ++d) { + HVX_Vector v0 = pv_in0[d]; + HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero(); + HVX_Vector v_hf = hvx_vec_f32_to_f16_shuff(v0, v1); + + HVX_Vector * out_tile = (HVX_Vector *) (out_base + d * HMX_FP16_TILE_N_ELMS); + out_tile[r1 / 2] = v_hf; + } + } else { + const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0; + const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL; + + for (uint32_t d = 0; d < DK / 64; ++d) { + HVX_Vector v0 = pv_in0[d]; + HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero(); + HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2); + + __fp16 * out_dual_tile = out_base + d * HMX_FP16_TILE_N_ELMS * 2; + HVX_Vector * pv_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2; + HVX_Vector * pv_out1 = pv_out0 + 16; + + *pv_out0 = Q6_V_lo_W(vp); + *pv_out1 = Q6_V_hi_W(vp); + } + } + } +} + +static void fa_phase_q_load(struct hmx_fa_context * factx, + const struct htp_tensor * q, + uint32_t q_start, + uint32_t kv_head, + uint32_t ib3, + size_t n_rows_g) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_q_load_args_t args = { factx, q, q_start, kv_head, ib3, n_rows_g }; + // Require >= 2 row pairs per thread so partitioning is worthwhile. + if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_q_load_thread, &args, factx->n_threads); + } else { + fa_q_load_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread O store phase: read O tiles from VTCM, convert F16->F32 (or +// deal F16 pairs), and write to strided DDR dst tensor. Each thread owns a +// disjoint row range; writes target distinct dst rows (different q_idx/h_idx +// pairs produced by r/G and r%G), so there is no write conflict. +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + const struct htp_tensor * dst; + const __fp16 * o_tile_src; + uint32_t q_start; + uint32_t kv_head; + uint32_t ib3; + size_t n_rows_g; +} fa_o_store_args_t; + +static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) { + fa_o_store_args_t * args = (fa_o_store_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t G = factx->G; + const size_t DV = factx->DV; + + const size_t rows_per_t = hmx_ceil_div(n_rows_g, n); + const size_t start = (size_t) i * rows_per_t; + const size_t end = hex_smin(start + rows_per_t, n_rows_g); + + if (start >= n_rows_g) { + return; + } + + const struct htp_tensor * dst = args->dst; + const __fp16 * o_tile_src = args->o_tile_src; + const uint32_t q_start = args->q_start; + const uint32_t kv_head = args->kv_head; + const uint32_t ib3 = args->ib3; + + for (size_t r = start; r < end; ++r) { + const size_t q_idx = r / G; + const size_t h_idx = r % G; + + // FIX(dst-indexing): ggml_flash_attn_ext() creates dst as permute(0,2,1,3) -> + // [DV, n_heads, n_tokens, n_seq], so head stride is nb[1] and token stride is nb[2]. + uint8_t * dst_row = (uint8_t *) dst->data + (kv_head * G + h_idx) * dst->nb[1] + + (q_start + q_idx) * dst->nb[2] + ib3 * dst->nb[3]; + + size_t r0 = r / HMX_FP16_TILE_N_ROWS; + size_t r1 = r % HMX_FP16_TILE_N_ROWS; + const __fp16 * tile_row_base = o_tile_src + r0 * HMX_FP16_TILE_N_ROWS * DV; + + if (factx->is_dst_fp32) { + float * out = (float *) dst_row; + for (uint32_t d = 0; d < DV / 32; ++d) { + const HVX_Vector * in_tile = (const HVX_Vector *) (tile_row_base + d * HMX_FP16_TILE_N_ELMS); + HVX_VectorPair vp = hvx_vec_f16_to_f32_shuff(in_tile[r1 / 2]); + if (r1 % 2 == 0) { + *(HVX_UVector *) (out + d * 32) = Q6_V_lo_W(vp); + } else { + *(HVX_UVector *) (out + d * 32) = Q6_V_hi_W(vp); + } + } + } else { + __fp16 * out = (__fp16 *) dst_row; + for (uint32_t d = 0; d < DV / 64; ++d) { + const __fp16 * in_dual_tile = tile_row_base + d * HMX_FP16_TILE_N_ELMS * 2; + const HVX_Vector * pv_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2; + const HVX_Vector * pv_in1 = pv_in0 + 16; + HVX_VectorPair vp = Q6_W_vdeal_VVR(*pv_in1, *pv_in0, -2); + if (r1 % 2 == 0) { + *(HVX_UVector *) (out + d * 64) = Q6_V_lo_W(vp); + } else { + *(HVX_UVector *) (out + d * 64) = Q6_V_hi_W(vp); + } + } + } + } +} + +static void fa_phase_o_store(struct hmx_fa_context * factx, + const struct htp_tensor * dst, + const __fp16 * o_tile_src, + uint32_t q_start, + uint32_t kv_head, + uint32_t ib3, + size_t n_rows_g) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_o_store_args_t args = { factx, dst, o_tile_src, q_start, kv_head, ib3, n_rows_g }; + if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_o_store_thread, &args, factx->n_threads); + } else { + fa_o_store_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread softmax phase + serial m/l update + build_D +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + size_t kv_rows; + size_t n_rows_g; + size_t n_col_tiles; + size_t n_tiles_per_bc; + size_t n_row_tiles; + size_t n_row_tiles_g_br; + uint32_t Bc; + uint32_t G; + uint32_t kv_head; + uint32_t kv_start; + uint32_t q_start; + uint32_t ib3; + bool has_alibi; // true when max_bias != 0 (need slope * mask + add) + + // ALiBi per-head slopes (indexed by GQA-merged row: slope[r] for r in [0, n_rows_g)) + // slope[r] = 1.0 when max_bias == 0 (no ALiBi) + // Pointer into hmx_fa_context.vtcm_slopes (sized to g_br) + __fp16 * slopes; + + // Mask info (preloaded before softmax) + const struct htp_tensor * mask; + const __fp16 * mask_vtcm; // VTCM mask buffer base (NULL = DDR fallback) + size_t mask_vtcm_row_stride; // elements (__fp16) per row in VTCM mask buffer +} fa_softmax_args_t; + +static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { + fa_softmax_args_t * args = (fa_softmax_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t kv_rows = args->kv_rows; + const size_t Bc = args->Bc; + const size_t G = args->G; + const size_t n_tiles_per_bc = args->n_tiles_per_bc; + const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64); + + // Partition r_vec_idx across threads + const size_t vecs_per_t = hmx_ceil_div(n_row_vec_cnt, n); + const size_t vec_start = i * vecs_per_t; + const size_t vec_end = hex_smin(vec_start + vecs_per_t, n_row_vec_cnt); + + if (vec_start >= n_row_vec_cnt) { + return; + } + + // Per-thread row scratch: thread i uses bufs at offset i * 2 * stride + const size_t row_buf_stride = factx->row_buf_stride; + HVX_Vector * my_row_buf0 = factx->vtcm_row_bufs + i * 2 * row_buf_stride; + HVX_Vector * my_row_buf1 = my_row_buf0 + row_buf_stride; + + const HVX_Vector v_neg_inf = Q6_Vh_vsplat_R(0xfbff); + + // Per-row accumulators: each fp16 lane in a 64-lane vector holds one row's scalar. + // CONTRACT: lane bits must be IEEE fp16 (hf), never qf16 — qf16 uses a different + // bit layout, so a later hf-domain read would silently produce wrong values. + // Convert first via Q6_Vhf_equals_Vqf16(). For reference: vtcm_m_vec/vtcm_s_rowmax + // are hf; vtcm_l_vec is qf16 — don't mix them up. + + for (size_t r_vec_idx = vec_start; r_vec_idx < vec_end; ++r_vec_idx) { + HVX_Vector rowmax_acc_v = v_neg_inf; + HVX_Vector rowsum_acc_v = Q6_V_vzero(); + HVX_Vector m_prev_v = factx->vtcm_m_vec[r_vec_idx]; + + for (int r_vec_off = 0; r_vec_off < 64; r_vec_off += 2) { + int r = r_vec_idx * 64 + r_vec_off; + if (r >= (int) hex_align_up(n_rows_g, 2)) { + break; + } + + int r0 = r / HMX_FP16_TILE_N_ROWS; + int r1 = r % HMX_FP16_TILE_N_ROWS; + + const __fp16 * s_ld_base = factx->vtcm_s_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc; + __fp16 * p_st_base = factx->vtcm_p_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc; + + // Decode 2 rows from S tiles into per-thread row buffers + HVX_Vector * pv_row_buf0 = my_row_buf0; + HVX_Vector * pv_row_buf1 = my_row_buf1; + for (size_t c = 0; c < kv_rows; c += 64) { + const __fp16 * in_dual_tile = s_ld_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2; + const HVX_Vector * pv_s_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2; + const HVX_Vector * pv_s_in1 = pv_s_in0 + 16; + + HVX_VectorPair vp_s_dual_row = Q6_W_vdeal_VVR(*pv_s_in1, *pv_s_in0, -2); + *pv_row_buf0++ = Q6_V_lo_W(vp_s_dual_row); + *pv_row_buf1++ = Q6_V_hi_W(vp_s_dual_row); + } + + // Apply softcap if enabled (in F32 precision) + if (factx->logit_softcap != 0.0f) { + // When EXP2_HF is on, fold log2(e) into v_cap so the output lands in + // log2(e)-scaled space for the downstream exp2. log2(e) is kept OUT + // of qk_scale in this configuration (see scale setup) so tanh sees + // the physical QK/(√d·c) argument. + float cap = factx->logit_softcap; +#ifdef HMX_FA_USE_EXP2_HF + cap *= 1.44269504f; // log2(e) +#endif + const HVX_Vector v_cap = hvx_vec_splat_f32(cap); + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + + HVX_VectorPair r0_f32 = hvx_vec_f16_to_f32(my_row_buf0[ci]); + HVX_Vector t0_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r0_f32)); + HVX_Vector t0_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r0_f32)); + t0_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_lo, v_cap)); + t0_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_hi, v_cap)); + my_row_buf0[ci] = hvx_vec_f32_to_f16(t0_lo, t0_hi); + + HVX_VectorPair r1_f32 = hvx_vec_f16_to_f32(my_row_buf1[ci]); + HVX_Vector t1_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r1_f32)); + HVX_Vector t1_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r1_f32)); + t1_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_lo, v_cap)); + t1_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_hi, v_cap)); + my_row_buf1[ci] = hvx_vec_f32_to_f16(t1_lo, t1_hi); + } + } + + // Apply mask & compute rowmax(S) + // + // Optimizations over baseline: + // A. No-ALiBi fast path: when max_bias==0 (slope≡1.0), skip the + // slope multiplication — still add mask (additive bias) but + // avoid the mul_f16_f16. Saves 2 ops/dual-row vs ALiBi path. + // B. GQA mask row dedup: G consecutive Q rows share one mask row + // (qi = r / G). Reuse mask vector when qi is unchanged between + // row0 and row1 (saves ~75% of VTCM loads for G=4). + + // ALiBi slopes — only needed when has_alibi (scheme A) + HVX_Vector v_slope0, v_slope1; + if (args->has_alibi) { + v_slope0 = hvx_vec_splat_f16(args->slopes[r + 0]); + v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_splat_f16(args->slopes[r + 1]) : Q6_V_vzero(); + } + + const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c) + + HVX_Vector v_s_rowmax0 = v_neg_inf; + HVX_Vector v_s_rowmax1 = v_neg_inf; + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + const size_t ne = hex_smin(kv_rows - c, 64); + HVX_VectorPred q_tail_keep = Q6_Q_vsetq2_R(ne * sizeof(__fp16)); + + if (args->mask) { + HVX_Vector v_mask0, v_mask1; + + if (args->mask_vtcm) { + // Read mask from VTCM buffer (DMA'd per KV block). + // GQA dedup (scheme B): skip load when qi unchanged. + const size_t qi0 = (r + 0) / G; + v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c); + v_mask1 = v_neg_inf; + if (r + 1 < (int) n_rows_g) { + const size_t qi1 = (r + 1) / G; + if (qi1 == qi0) { + v_mask1 = v_mask0; // scheme B: reuse — same mask row + } else { + v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c); + } + } + } else { + // Fallback: read mask directly from DDR (when mask->ne[2] > 1). + const struct htp_tensor * mask = args->mask; + const size_t q_idx0 = args->q_start + ((r + 0) / G); + const size_t h_idx0 = args->kv_head * G + (r + 0) % G; + const uint32_t im2_0 = h_idx0 % mask->ne[2]; + const uint32_t im3_0 = args->ib3 % mask->ne[3]; + + const __fp16 * m0_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx0 * mask->nb[1] + + im2_0 * mask->nb[2] + im3_0 * mask->nb[3]) + args->kv_start + c; + v_mask0 = *(const HVX_UVector *) m0_ptr; + v_mask1 = v_neg_inf; + + if (r + 1 < (int) n_rows_g) { + const size_t q_idx1 = args->q_start + ((r + 1) / G); + if (q_idx1 == q_idx0) { + // scheme B: same mask row in DDR path + v_mask1 = v_mask0; + } else { + const size_t h_idx1 = args->kv_head * G + (r + 1) % G; + const uint32_t im2_1 = h_idx1 % mask->ne[2]; + const uint32_t im3_1 = args->ib3 % mask->ne[3]; + const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] + + im2_1 * mask->nb[2] + im3_1 * mask->nb[3]) + args->kv_start + c; + v_mask1 = *(const HVX_UVector *) m1_ptr; + } + } + } + + // Threshold: mask values below -16.0 are treated as -inf (causal mask). + HVX_VectorPred q_keep0 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask0, v_threshold), q_tail_keep); + HVX_VectorPred q_keep1 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask1, v_threshold), q_tail_keep); + + if (args->has_alibi) { + // ALiBi path: S += slope * mask (full mul + add) + HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0); + HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1); + my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf); + } else { + // No-ALiBi fast path (scheme A): slope≡1.0, skip the mul + // but still add mask (additive positional bias). vmux + // clamps mask < -16 to -inf as a numerical safeguard. + my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_mask0), v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_mask1), v_neg_inf); + } + } else { + if (ne < 64) { + my_row_buf0[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf0[ci], v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf1[ci], v_neg_inf); + } + } + + v_s_rowmax0 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax0, my_row_buf0[ci]); + v_s_rowmax1 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax1, my_row_buf1[ci]); + } + + v_s_rowmax0 = hvx_vec_reduce_max_f16(v_s_rowmax0); + v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1); + + // Splat m_prev[r], m_prev[r+1] from the per-row accumulator. + // vror brings the target lane to lane 0, then extract + re-splat. + HVX_Vector v_m_prev0 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2))); + HVX_Vector v_m_prev1 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2))); + + // HVX max — both operands are splats, so result is splat of m_new. + HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0); + HVX_Vector v_dup_m1 = Q6_Vhf_vmax_VhfVhf(v_m_prev1, v_s_rowmax1); + + // Insert row r, r+1 rowmax into rowmax_acc_v via 2-byte-wide vmux. + // Byte ranges: lane0 = [r_vec_off*2 .. r_vec_off*2+1], lane1 shifted by 2. + // vsetq2 handles the n=128 corner case when r_vec_off reaches 62. + { + HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2); + HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2); + HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2); + HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start); + HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid); + rowmax_acc_v = Q6_V_vmux_QVV(p_lane0, v_dup_m0, rowmax_acc_v); + rowmax_acc_v = Q6_V_vmux_QVV(p_lane1, v_dup_m1, rowmax_acc_v); + } + + // Compute P = exp(S - m_new), using HVX exp + const HVX_Vector v_zero = Q6_V_vzero(); + HVX_Vector v_p_rowsum0 = v_zero; + HVX_Vector v_p_rowsum1 = v_zero; + +#ifdef HMX_FA_USE_EXP2_HF + // FP16 exp2 polynomial path (matches htp-ops-lib flash_attn.c): + // P = exp2(S - m_new) + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0); + HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1); + + HVX_Vector v_p_row0_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m0)); + HVX_Vector v_p_row1_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m1)); +#else + // F32 exp path: qf16 → f32 → exp → f32 → f16. Higher precision, + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0); + HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1); + + HVX_VectorPair vp0 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m0)); + HVX_Vector p0_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp0)); + HVX_Vector p0_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp0)); + HVX_Vector v_p_row0_hf = hvx_vec_f32_to_f16_shuff(p0_lo, p0_hi); + + HVX_VectorPair vp1 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m1)); + HVX_Vector p1_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp1)); + HVX_Vector p1_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp1)); + HVX_Vector v_p_row1_hf = hvx_vec_f32_to_f16_shuff(p1_lo, p1_hi); +#endif + // Write P to tile format. Dual-tile pattern assumes Bc is a + // multiple of 64 (enforced by bc_unit=64 in hmx_fa_find_chunk_size), + // so both tile halves are always in the current r0 block. + __fp16 * out_dual_tile = p_st_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2; + HVX_Vector * pv_p_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2; + HVX_Vector * pv_p_out1 = pv_p_out0 + 16; + + HVX_VectorPair vp_p_dual = Q6_W_vshuff_VVR(v_p_row1_hf, v_p_row0_hf, -2); + *pv_p_out0 = Q6_V_lo_W(vp_p_dual); + *pv_p_out1 = Q6_V_hi_W(vp_p_dual); + + HVX_VectorPair vp_p0 = hvx_vec_f16_to_f32_shuff(v_p_row0_hf); + HVX_VectorPair vp_p1 = hvx_vec_f16_to_f32_shuff(v_p_row1_hf); + + v_p_rowsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum0, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p0), Q6_V_hi_W(vp_p0))); + v_p_rowsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum1, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p1), Q6_V_hi_W(vp_p1))); + } + + HVX_Vector rowsum0_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum0)); + HVX_Vector rowsum1_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum1)); + { + // Both inputs are f32 splats, so the f32->f16 output is an fp16 splat. + HVX_Vector rv0_v = hvx_vec_f32_to_f16(rowsum0_sf, rowsum0_sf); + HVX_Vector rv1_v = hvx_vec_f32_to_f16(rowsum1_sf, rowsum1_sf); + + HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2); + HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2); + HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2); + HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start); + HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid); + rowsum_acc_v = Q6_V_vmux_QVV(p_lane0, rv0_v, rowsum_acc_v); + rowsum_acc_v = Q6_V_vmux_QVV(p_lane1, rv1_v, rowsum_acc_v); + } + } + + factx->vtcm_s_rowmax[r_vec_idx] = rowmax_acc_v; + factx->vtcm_p_rowsum[r_vec_idx] = rowsum_acc_v; + } +} + +// Serial m/l update + build_D. Must run after softmax barrier (s_rowmax written by all threads). +// +// noinline: function boundary acts as a hard compiler barrier so the (size_t)addr scatter +// intrinsics inside cannot be hoisted past the call site. Mirrors the structural protection +// matmul gets for free via worker_pool function-pointer dispatch. Without this, the compiler +// can reorder the scatter past the subsequent hmx_queue_push and the HMX-queue worker thread +// reads stale VTCM (PPL → ~vocab-size). +static __attribute__((noinline)) void fa_ml_update_and_build_d(struct hmx_fa_context * factx, + size_t n_rows_g, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + // Reuse s_rowmax buffer for exp(m_diff) — safe because softmax is fully complete + HVX_Vector * const mvec_exp_m_diff = factx->vtcm_s_rowmax; + + const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64); + for (size_t i = 0; i < n_row_vec_cnt; ++i) { + HVX_Vector v_m_prev = factx->vtcm_m_vec[i]; + HVX_Vector v_m_curr = Q6_Vhf_vmax_VhfVhf(v_m_prev, factx->vtcm_s_rowmax[i]); + HVX_Vector v_m_diff = Q6_Vqf16_vsub_VhfVhf(v_m_prev, v_m_curr); + +#ifdef HMX_FA_USE_EXP2_HF + // Base-2 path: must match P = exp2(S - m_new) in fa_softmax_thread. + HVX_Vector v_exp_m_diff = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_m_diff)); +#else + HVX_VectorPair vp_diff = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_m_diff)); + HVX_Vector exp_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp_diff)); + HVX_Vector exp_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp_diff)); + HVX_Vector v_exp_m_diff = hvx_vec_f32_to_f16_shuff(exp_lo, exp_hi); +#endif + + HVX_Vector v_l_curr = Q6_Vqf16_vmpy_Vqf16Vhf(factx->vtcm_l_vec[i], v_exp_m_diff); + v_l_curr = Q6_Vqf16_vadd_Vqf16Vhf(v_l_curr, factx->vtcm_p_rowsum[i]); + + factx->vtcm_m_vec[i] = v_m_curr; + factx->vtcm_l_vec[i] = v_l_curr; + mvec_exp_m_diff[i] = v_exp_m_diff; + } + + // Build diagonal tile D = diag(exp(m_diff)) + const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets; + const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16)); + for (size_t i = 0; i < n_row_tiles; ++i) { + const HVX_Vector v_content = Q6_V_vror_VR(mvec_exp_m_diff[i / 2], (i % 2) * 64); + __fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content); + // Compiler barrier — Q6_vscatter takes (size_t)addr; without this the + // compiler may not recognize the volatile read below as aliasing and + // could reorder it before the scatter, defeating the HW drain. + __asm__ __volatile__("" ::: "memory"); + // Per-tile drain: scatter regions are disjoint (stride > tile size), + // so a single drain at tile 0 does NOT retire later tiles' entries. + (void) *(volatile HVX_Vector *) out_base; + } +} + +// Build D = diag(1/l) tile for the final O = D @ O normalization. +// +// noinline: same rationale as fa_ml_update_and_build_d — keeps Q6_vscatter from +// being hoisted past the subsequent hmx_queue_push at the o_norm call site. +static __attribute__((noinline)) void fa_build_d_diag_inv_l(struct hmx_fa_context * factx, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets; + const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16)); + const HVX_Vector one = hvx_vec_splat_f32(1.0f); + + HVX_Vector v_content = Q6_V_vzero(); + for (size_t i = 0; i < n_row_tiles; ++i) { + if ((i % 2) == 0) { + HVX_Vector v_l_hf = Q6_Vhf_equals_Vqf16(factx->vtcm_l_vec[i / 2]); + HVX_VectorPair vp_l = hvx_vec_f16_to_f32_shuff(v_l_hf); + HVX_Vector inv_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_lo_W(vp_l)))); + HVX_Vector inv_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_hi_W(vp_l)))); + v_content = hvx_vec_f32_to_f16_shuff(inv_lo, inv_hi); + } else { + v_content = Q6_V_vror_VR(v_content, 64); + } + + __fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content); + // Compiler barrier — see fa_ml_update_and_build_d for rationale. + __asm__ __volatile__("" ::: "memory"); + (void) *(volatile HVX_Vector *) out_base; + } +} + +// Combined: multi-thread softmax -> barrier -> serial m/l update + build_D +static void fa_phase_softmax_and_build_d(struct hmx_fa_context * factx, + fa_softmax_args_t * sargs, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + const size_t n_row_vec_cnt = hmx_ceil_div(sargs->n_rows_g, 64); + + if (factx->n_threads > 1 && n_row_vec_cnt >= 2) { + uint32_t n_use = (uint32_t) hex_smin((size_t) factx->n_threads, n_row_vec_cnt); + worker_pool_run_func(wp, fa_softmax_thread, sargs, n_use); + } else { + fa_softmax_thread(1, 0, sargs); + } + // barrier implicit in worker_pool_run_func return + + fa_ml_update_and_build_d(factx, sargs->n_rows_g, n_row_tiles, n_row_tiles_g_br); +} + +// ============================================================================ +// HMX job structs and worker functions +// ============================================================================ + +typedef struct { + const __fp16 * q_tiles; + const __fp16 * k_tiles; + __fp16 * s_tiles; + size_t n_row_tiles; + size_t n_col_tiles; + size_t n_dot_tiles; // DK / 32 + size_t n_tiles_per_bc; + uint8_t * hmx_scales; +} hmx_fa_qk_job_t; + +static void hmx_fa_qk_dot_worker(void * data) { + hmx_fa_qk_job_t * job = (hmx_fa_qk_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_col_tiles = job->n_col_tiles; + const size_t n_dot_tiles = job->n_dot_tiles; + const size_t n_tiles_per_bc = job->n_tiles_per_bc; + const __fp16 * restrict q_tiles = job->q_tiles; + const __fp16 * restrict k_tiles = job->k_tiles; + __fp16 * restrict s_tiles = job->s_tiles; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + const __fp16 * row_tiles = q_tiles + r * HMX_FP16_TILE_N_ROWS * n_dot_tiles * HMX_FP16_TILE_N_COLS; + const __fp16 * col_tiles = k_tiles + c * HMX_FP16_TILE_N_COLS * n_dot_tiles * HMX_FP16_TILE_N_COLS; + __fp16 * out_tile = s_tiles + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS; + + for (size_t k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } +} + +typedef struct { + __fp16 * o_curr; + const __fp16 * o_prev; + const __fp16 * p_tiles; + const __fp16 * v_tiles; + const __fp16 * d_tiles; + uint8_t * hmx_scales; + size_t n_row_tiles; + size_t n_col_tiles; + size_t n_row_tiles_g_br; + size_t n_tiles_per_bc; + size_t DV; +} hmx_fa_o_update_job_t; + +static void hmx_fa_o_update_worker(void * data) { + hmx_fa_o_update_job_t * job = (hmx_fa_o_update_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_col_tiles = job->n_col_tiles; + const size_t n_row_tiles_g_br = job->n_row_tiles_g_br; + const size_t n_tiles_per_bc = job->n_tiles_per_bc; + const size_t DV_tiles = job->DV / 32; + const __fp16 * restrict d_tiles = job->d_tiles; + const __fp16 * restrict p_tiles = job->p_tiles; + const __fp16 * restrict v_tiles = job->v_tiles; + const __fp16 * restrict o_prev = job->o_prev; + __fp16 * restrict o_curr = job->o_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + // D[r,r] @ O_prev[r,c] — only the diagonal tile + const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + + // P @ V (accumulate on same accumulator) + const __fp16 * p_tile_in = p_tiles + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + const __fp16 * v_tile_in = v_tiles + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_col_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047); + p_tile_in += HMX_FP16_TILE_N_ELMS; + v_tile_in += HMX_FP16_TILE_N_ELMS; + } + + __fp16 * o_tile_out = o_curr + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(o_tile_out, 0); + } + } +} + +typedef struct { + __fp16 * o_curr; // output (row-major tile layout) + const __fp16 * o_prev; // input (column-major tile layout) + const __fp16 * d_tiles; // diag(1/l) tiles + uint8_t * hmx_scales; + size_t n_row_tiles; + size_t n_row_tiles_g_br; + size_t DV; +} hmx_fa_o_norm_job_t; + +static void hmx_fa_o_norm_worker(void * data) { + hmx_fa_o_norm_job_t * job = (hmx_fa_o_norm_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_row_tiles_g_br = job->n_row_tiles_g_br; + const size_t DV_tiles = job->DV / 32; + const __fp16 * restrict d_tiles = job->d_tiles; + const __fp16 * restrict o_prev = job->o_prev; + __fp16 * restrict o_curr = job->o_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + __fp16 * o_out = o_curr + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS; + + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + Q6_mxmem_AR_after_hf(o_out, 0); + } + } +} + +// Populate per-GQA-row ALiBi slopes for a given KV head. +// Row r in the GQA-merged block maps to Q head h = kv_head * G + r % G. +// slope(h) = m0^(h+1) when h < n_head_log2, else m1^(2*(h-n_head_log2)+1). +// When max_bias == 0, all slopes are 1.0 (no ALiBi). +static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sargs, + const struct hmx_fa_context * factx, + uint32_t kv_head, + size_t n_rows_g) { + if (factx->max_bias == 0.0f) { + for (size_t r = 0; r < n_rows_g; ++r) { + sargs->slopes[r] = 1.0f; + } + return; + } + + const uint32_t G = factx->G; + const uint32_t n_head_log2 = factx->n_head_log2; + const float m0 = factx->m0; + const float m1 = factx->m1; + + for (size_t r = 0; r < n_rows_g; ++r) { + const uint32_t h = kv_head * G + r % G; + sargs->slopes[r] = (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); + } +} + +// ============================================================================ +// Core HMX flash attention algorithm (GQA-merged) +// ============================================================================ + +int hmx_flash_attn_ext(struct htp_ops_context * octx) { + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = (octx->src[3] && octx->src[3]->data) ? octx->src[3] : NULL; + const struct htp_tensor * dst = octx->dst; + + struct htp_context * const ctx = octx->ctx; + + if (!ctx->hmx_enabled) { + return HTP_STATUS_NO_SUPPORT; + } + + // Dimensions + const uint32_t neq0 = q->ne[0]; // head_dim (DK) + const uint32_t neq1 = q->ne[1]; // n_tokens + const uint32_t neq2 = q->ne[2]; // n_heads + const uint32_t neq3 = q->ne[3]; // n_seqs + + const uint32_t nek0 = k->ne[0]; // head_dim + const uint32_t nek1 = k->ne[1]; // kv_len + + const uint32_t nev0 = v->ne[0]; // head_dim (DV) + + const uint32_t DK = neq0; + const uint32_t DV = nev0; + + // HMX requires head_dim to be multiple of 32 + if (DK % 32 != 0 || DV % 32 != 0) { + return HTP_STATUS_NO_SUPPORT; + } + if (neq1 < 32) { + return HTP_STATUS_NO_SUPPORT; + } + + // GQA factor + const uint32_t n_kv_heads = k->ne[2]; + const uint32_t G = neq2 / n_kv_heads; + + // Thread count for multi-thread HVX phases + const uint32_t n_threads = octx->n_threads; + + // Compute dynamic block sizes (GQA-aware, accounting for per-thread row bufs) + size_t Br, Bc; + const size_t vtcm_budget = ctx->vtcm_size; + if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads) != 0) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); + + const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; + const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2); + + FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", + neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); + + // ======== Build context ======== + struct hmx_fa_context factx; + memset(&factx, 0, sizeof(factx)); + factx.octx = octx; + factx.n_threads = octx->ctx->n_threads; + factx.DK = DK; + factx.DV = DV; + factx.n_kv = nek1; + factx.n_kv_heads = n_kv_heads; + factx.n_heads = neq2; + factx.G = G; + factx.neq1 = neq1; + factx.Br = (uint32_t) Br; + factx.Bc = (uint32_t) Bc; + factx.g_br = (uint32_t) g_br; + factx.n_kv_blocks = n_kv_blocks; + factx.is_q_fp32 = (q->type == HTP_TYPE_F32); + factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32); + factx.use_pipeline = use_pipeline; + factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1); + + // Extract op parameters (mutable during softcap adjustment, then stored as const in factx) + float scale = 1.0f, max_bias = 0.0f, logit_softcap = 0.0f; + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + +#ifdef HMX_FA_USE_EXP2_HF + // Pre-bake log2(e) into qk_scale so HMX-produced S tiles are in log2(e)-scaled + // space. Then exp2(S - m) in the softmax equals base-e exp((S - m) / log2(e)), + // preserving ggml's base-e softmax semantics. Matches htp-ops-lib flash_attn.c. + // + // When softcap is active we cannot pre-bake log2(e) here — it would land inside + // the tanh argument and shift the softcap knee from x≈c to x≈c/log2(e), giving + // numerically wrong softcapped values. Instead fold log2(e) into the post-tanh + // multiplier (see softcap block: v_cap absorbs log2(e)). + if (logit_softcap == 0.0f) { + scale *= 1.44269504f; // log2(e) + } +#endif + + factx.scale = scale; + factx.max_bias = max_bias; + factx.logit_softcap = logit_softcap; + + factx.n_head_log2 = 1u << (uint32_t) floor(log2(neq2)); + factx.m0 = powf(2.0f, -(max_bias) / factx.n_head_log2); + factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + + // ======== VTCM allocation (GQA-aware) ======== + const size_t q_tile_bytes = hex_align_up(g_br * DK * sizeof(__fp16), 4096); + const size_t o_tile_bytes = hex_align_up(g_br * DV * sizeof(__fp16), 4096); + const size_t k_dma_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); + const size_t v_dma_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t k_tile_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); + const size_t v_tile_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t s_tile_bytes = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); + const size_t d_tile_bytes = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); + const size_t col_vec_bytes = hex_align_up(g_br * sizeof(__fp16), 256); + const size_t row_vec_bytes = hex_align_up(Bc * sizeof(__fp16), 256); + const size_t m_line_bytes = hex_align_up(Bc * sizeof(__fp16), 128); + const size_t m_buf_bytes = hex_align_up(Br * m_line_bytes, 4096); + const size_t slopes_bytes = hex_align_up(g_br * sizeof(__fp16), 128); + + uint8_t * vtcm_cur = ctx->vtcm_base; + + factx.vtcm_q_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, q_tile_bytes); + factx.vtcm_o_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes); + factx.vtcm_o_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes); + factx.vtcm_k_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes); + factx.vtcm_k_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes); + factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); + factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); + factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); + factx.vtcm_v_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); + factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); + factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes); + factx.vtcm_m_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_l_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_s_rowmax = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_p_rowsum = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_row_bufs = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, row_vec_bytes * 2 * n_threads); + factx.row_buf_stride = row_vec_bytes / sizeof(HVX_Vector); + factx.vtcm_hmx_scales_id = vtcm_seq_alloc(&vtcm_cur, 256); + factx.vtcm_hmx_scales_qk = vtcm_seq_alloc(&vtcm_cur, 256); + factx.vtcm_mask_buf = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, m_buf_bytes); + factx.mask_buf_row_stride = m_line_bytes / sizeof(__fp16); + factx.vtcm_slopes = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, slopes_bytes); + + if ((size_t) (vtcm_cur - ctx->vtcm_base) > ctx->vtcm_size) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + // ======== Initialize HMX output scales ======== + // Identity scale (1.0) for O updates and normalization + hmx_init_column_scales(factx.vtcm_hmx_scales_id, Q6_V_vsplat_R(0x3c00)); // 1.0 + + // QK scale embedded in HMX output + hmx_init_column_scales(factx.vtcm_hmx_scales_qk, hvx_vec_splat_f16(factx.scale)); + + // ======== Skip compute if profiling ======== + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + // Profiling timers + TIMER_DEFINE(total); + TIMER_DEFINE(q_load); + TIMER_DEFINE(kv_dma); + TIMER_DEFINE(k_interleave); + TIMER_DEFINE(v_interleave); + TIMER_DEFINE(qk_dot); + TIMER_DEFINE(softmax); + TIMER_DEFINE(o_update); + TIMER_DEFINE(o_norm); + TIMER_DEFINE(o_store); + + TIMER_START(total); + + // ======== DMA setup ======== + dma_queue * const dma = ctx->dma[0]; + + // Padded row sizes for DMA + const size_t size_k_row = nek0 * sizeof(__fp16); + const size_t size_v_row = nev0 * sizeof(__fp16); + const size_t size_k_row_padded = hex_round_up(nek0 * sizeof(__fp16), 128); + const size_t size_v_row_padded = hex_round_up(nev0 * sizeof(__fp16), 128); + + const size_t n_row_tiles_g_br = g_br / HMX_FP16_TILE_N_ROWS; + const size_t n_tiles_per_bc = Bc / HMX_FP16_TILE_N_COLS; + + // Q/O element size for Q load and O store + const size_t qo_element_size = factx.is_q_fp32 ? sizeof(float) : sizeof(__fp16); + + // ======== HMX lock strategy ======== + // Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend. + // Fallback: main thread holds the lock (original behavior). + if (!factx.use_pipeline) { + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + } + + // ======== Reusable job descriptors for pipeline ======== + hmx_fa_qk_job_t qk_job; + hmx_fa_o_update_job_t ou_job; + hmx_fa_o_norm_job_t on_job; + + // ======== Main loop: per batch, per KV head, per Q block ======== + for (uint32_t ib3 = 0; ib3 < neq3; ++ib3) { + for (uint32_t kv_head = 0; kv_head < n_kv_heads; ++kv_head) { + const uint32_t ik2 = kv_head; + const uint32_t ik3 = ib3 / (neq3 / k->ne[3]); + const uint32_t iv2 = kv_head; + const uint32_t iv3 = ib3 / (neq3 / v->ne[3]); + + for (uint32_t q_start = 0; q_start < neq1; q_start += Br) { + const uint32_t n_q_rows = hex_smin(Br, neq1 - q_start); + const size_t n_rows_g = n_q_rows * G; + const size_t g_br_actual = hex_align_up(n_rows_g, HMX_FP16_TILE_N_ROWS); + const size_t n_row_tiles = g_br_actual / HMX_FP16_TILE_N_ROWS; + + // ---- Load Q block [g_br, D] -> tiles, interleaving G heads ---- + TIMER_START(q_load); + if (n_rows_g < g_br) { + hvx_splat_u8_a(factx.vtcm_q_tiles, 0, q_tile_bytes); + } + fa_phase_q_load(&factx, q, q_start, kv_head, ib3, n_rows_g); + TIMER_STOP(q_load); + + // ---- Initialize per-block state ---- + hvx_splat_u8_a(factx.vtcm_l_vec, 0, col_vec_bytes); + hvx_splat_u8_a(factx.vtcm_d_tiles, 0, d_tile_bytes); + hvx_splat_u16_a(factx.vtcm_m_vec, 0xfbff, col_vec_bytes/2); + + __fp16 * o_tile_prev = factx.vtcm_o_tiles[0]; + __fp16 * o_tile_curr = factx.vtcm_o_tiles[1]; + hvx_splat_u8_a(o_tile_prev, 0, o_tile_bytes); + + // ---- KV block loop with DMA double-buffering ---- + size_t buf_idx = 0; + + // Prefetch first KV block + if (factx.n_kv_blocks > 0) { + const uint32_t kv_rows0 = hex_smin(Bc, nek1); + + const uint8_t * k_src = (const uint8_t *) k->data + ik2 * k->nb[2] + ik3 * k->nb[3]; + dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[0], k_src), size_k_row_padded, k->nb[1], + size_k_row, kv_rows0); + + const uint8_t * v_src = (const uint8_t *) v->data + iv2 * v->nb[2] + iv3 * v->nb[3]; + dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[0], v_src), size_v_row_padded, v->nb[1], + size_v_row, kv_rows0); + } + + // Mask DMA: single 2D transfer of n_q_rows unique mask rows into VTCM buffer. + // Only when mask is head-broadcast (ne[2]==1); otherwise softmax reads DDR directly. + #define MASK_DMA_PUSH(kv_start_val, kv_rows_val, has_mask_dma_var) \ + do { \ + has_mask_dma_var = false; \ + if (mask && factx.mask_broadcast) { \ + const uint32_t _im3 = ib3 % mask->ne[3]; \ + const uint8_t * _ms = (const uint8_t *) mask->data + q_start * mask->nb[1] + _im3 * mask->nb[3] + \ + (kv_start_val) * sizeof(__fp16); \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_mask_buf, _ms), m_line_bytes, mask->nb[1], \ + (kv_rows_val) * sizeof(__fp16), n_q_rows); \ + has_mask_dma_var = true; \ + } \ + } while (0) + + #define MASK_DMA_POP(has_mask_dma_var) \ + do { \ + if (has_mask_dma_var) { \ + dma_queue_pop(dma); \ + } \ + } while (0) + + #define DMA_PREFETCH_KV(blk_val) \ + do { \ + if ((blk_val) < factx.n_kv_blocks) { \ + const uint32_t _ns = (blk_val) * Bc; \ + const uint32_t _nr = hex_smin(Bc, nek1 - _ns); \ + size_t _nb = 1 - buf_idx; \ + const uint8_t * _ks = (const uint8_t *) k->data + _ns * k->nb[1] + ik2 * k->nb[2] + ik3 * k->nb[3]; \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[_nb], _ks), size_k_row_padded, k->nb[1], size_k_row, _nr); \ + const uint8_t * _vs = (const uint8_t *) v->data + _ns * v->nb[1] + iv2 * v->nb[2] + iv3 * v->nb[3]; \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[_nb], _vs), size_v_row_padded, v->nb[1], size_v_row, _nr); \ + } \ + } while (0) + + const size_t k_src_stride = size_k_row_padded / sizeof(__fp16); + const size_t v_src_stride = size_v_row_padded / sizeof(__fp16); + + if (factx.use_pipeline) { + // ================================================================== + // Pipeline path: HVX phases ‖ HMX queue worker + // ================================================================== + struct hmx_queue * hmx_q = ctx->hmx_queue; + + for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) { + const uint32_t kv_start = kv_blk * Bc; + const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start); + const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS); + + // Wait for current KV DMA + TIMER_START(kv_dma); + dma_queue_pop(dma); // K + dma_queue_pop(dma); // V + TIMER_STOP(kv_dma); + + // Push mask DMA for this block (single 2D DMA when broadcast) + bool has_mask_dma = false; + MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma); + + // ---- Phase 1: K_int(blk) ‖ O_update(blk-1) ---- + if (kv_blk > 0) { + // Submit O_update for previous block (HMX worker) + ou_job.o_curr = o_tile_curr; + ou_job.o_prev = o_tile_prev; + ou_job.p_tiles = factx.vtcm_p_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.d_tiles = factx.vtcm_d_tiles; + ou_job.hmx_scales = factx.vtcm_hmx_scales_id; + ou_job.n_row_tiles = n_row_tiles; + ou_job.n_col_tiles = hmx_ceil_div(hex_smin(Bc, nek1 - (kv_blk - 1) * Bc), HMX_FP16_TILE_N_COLS); + ou_job.n_row_tiles_g_br = n_row_tiles_g_br; + ou_job.n_tiles_per_bc = n_tiles_per_bc; + ou_job.DV = DV; + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job)); + } + + TIMER_START(k_interleave); + fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); + TIMER_STOP(k_interleave); + + if (kv_blk > 0) { + hmx_queue_pop(hmx_q); + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + // ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ---- + qk_job.q_tiles = factx.vtcm_q_tiles; + qk_job.k_tiles = factx.vtcm_k_tiles; + qk_job.s_tiles = factx.vtcm_s_tiles; + qk_job.n_row_tiles = n_row_tiles; + qk_job.n_col_tiles = n_col_tiles; + qk_job.n_dot_tiles = DK / 32; + qk_job.n_tiles_per_bc = n_tiles_per_bc; + qk_job.hmx_scales = factx.vtcm_hmx_scales_qk; + TIMER_START(qk_dot); + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_qk_dot_worker, &qk_job)); + + // DMA push next block (non-blocking, before worker_pool) + DMA_PREFETCH_KV(kv_blk + 1); + + TIMER_START(v_interleave); + fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); + TIMER_STOP(v_interleave); + + hmx_queue_pop(hmx_q); + TIMER_STOP(qk_dot); + + // ---- Phase 3: softmax(blk) + build_D(blk) | HMX idle ---- + // Pop mask DMA before softmax (ensures VTCM buffer is ready) + MASK_DMA_POP(has_mask_dma); + + fa_softmax_args_t sargs; + memset(&sargs, 0, sizeof(sargs)); + sargs.factx = &factx; + sargs.kv_rows = kv_rows; + sargs.n_rows_g = n_rows_g; + sargs.n_col_tiles = n_col_tiles; + sargs.n_tiles_per_bc = n_tiles_per_bc; + sargs.n_row_tiles = n_row_tiles; + sargs.n_row_tiles_g_br = n_row_tiles_g_br; + sargs.Bc = Bc; + sargs.G = G; + sargs.kv_head = kv_head; + sargs.kv_start = kv_start; + sargs.q_start = q_start; + sargs.ib3 = ib3; + sargs.has_alibi = (factx.max_bias != 0.0f); + sargs.mask = mask; + sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; + sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; + sargs.slopes = factx.vtcm_slopes; + fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); + + TIMER_START(softmax); + fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); + TIMER_STOP(softmax); + + buf_idx = 1 - buf_idx; + } // end KV block loop (pipeline) + + // Epilogue: O_update for last block + if (factx.n_kv_blocks > 0) { + const uint32_t last_blk = factx.n_kv_blocks - 1; + const size_t last_cols = hmx_ceil_div(hex_smin(Bc, nek1 - last_blk * Bc), HMX_FP16_TILE_N_COLS); + ou_job.o_curr = o_tile_curr; + ou_job.o_prev = o_tile_prev; + ou_job.p_tiles = factx.vtcm_p_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.d_tiles = factx.vtcm_d_tiles; + ou_job.hmx_scales = factx.vtcm_hmx_scales_id; + ou_job.n_row_tiles = n_row_tiles; + ou_job.n_col_tiles = last_cols; + ou_job.n_row_tiles_g_br = n_row_tiles_g_br; + ou_job.n_tiles_per_bc = n_tiles_per_bc; + ou_job.DV = DV; + + TIMER_START(o_update); + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job)); + hmx_queue_pop(hmx_q); + TIMER_STOP(o_update); + + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + } else { + // ================================================================== + // Fallback path: sequential with multi-thread HVX phases + // Main thread holds HMX lock, runs HMX inline. + // ================================================================== + + for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) { + const uint32_t kv_start = kv_blk * Bc; + const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start); + const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS); + + TIMER_START(kv_dma); + dma_queue_pop(dma); // K + dma_queue_pop(dma); // V + TIMER_STOP(kv_dma); + + bool has_mask_dma = false; + MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma); + DMA_PREFETCH_KV(kv_blk + 1); + + // K interleave (multi-thread HVX) + TIMER_START(k_interleave); + fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); + TIMER_STOP(k_interleave); + + // QK dot (inline HMX on main thread) + TIMER_START(qk_dot); + { + const size_t n_dot_tiles = (size_t) (DK / 32); + const __fp16 * restrict q_base = factx.vtcm_q_tiles; + const __fp16 * restrict k_base = factx.vtcm_k_tiles; + __fp16 * restrict s_base = factx.vtcm_s_tiles; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_qk); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + const __fp16 * row_tiles = q_base + r * HMX_FP16_TILE_N_ROWS * DK; + const __fp16 * col_tiles = k_base + c * HMX_FP16_TILE_N_COLS * DK; + __fp16 * out_tile = s_base + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } + } + TIMER_STOP(qk_dot); + + // Pop mask DMA + MASK_DMA_POP(has_mask_dma); + + // Softmax + build_D (multi-thread HVX + serial m/l update) + fa_softmax_args_t sargs; + memset(&sargs, 0, sizeof(sargs)); + sargs.factx = &factx; + sargs.kv_rows = kv_rows; + sargs.n_rows_g = n_rows_g; + sargs.n_col_tiles = n_col_tiles; + sargs.n_tiles_per_bc = n_tiles_per_bc; + sargs.n_row_tiles = n_row_tiles; + sargs.n_row_tiles_g_br = n_row_tiles_g_br; + sargs.Bc = Bc; + sargs.G = G; + sargs.kv_head = kv_head; + sargs.kv_start = kv_start; + sargs.q_start = q_start; + sargs.ib3 = ib3; + sargs.has_alibi = (factx.max_bias != 0.0f); + sargs.mask = mask; + sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; + sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; + sargs.slopes = factx.vtcm_slopes; + fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); + + TIMER_START(softmax); + fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); + TIMER_STOP(softmax); + + // V interleave (multi-thread HVX) + TIMER_START(v_interleave); + // FIX(v-stride): use n_tiles_per_bc (block-invariant) as V tile layout + // stride to match o_update's v_tile access. Using per-block n_col_tiles + // misplaces DV_tile 1..3 in the last partial KV block. + fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); + TIMER_STOP(v_interleave); + + // O update (inline HMX on main thread) + TIMER_START(o_update); + { + const size_t DV_tiles = (size_t) (DV / 32); + const __fp16 * restrict d_base = factx.vtcm_d_tiles; + const __fp16 * restrict p_base = factx.vtcm_p_tiles; + const __fp16 * restrict v_base = factx.vtcm_v_tiles; + const __fp16 * restrict op_base = o_tile_prev; + __fp16 * restrict oc_base = o_tile_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + + const __fp16 * p_tile_in = p_base + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + const __fp16 * v_tile_in = v_base + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_col_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047); + p_tile_in += HMX_FP16_TILE_N_ELMS; + v_tile_in += HMX_FP16_TILE_N_ELMS; + } + + __fp16 * o_tile_out = oc_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(o_tile_out, 0); + } + } + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + TIMER_STOP(o_update); + + buf_idx = 1 - buf_idx; + } // end KV block loop (fallback) + } + + // ---- Final normalization: O = diag(1/l) @ O ---- + TIMER_START(o_norm); + { + fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br); + + // HMX: O_final = diag(1/l) @ O_prev + if (factx.use_pipeline) { + on_job.o_curr = o_tile_curr; + on_job.o_prev = o_tile_prev; + on_job.d_tiles = factx.vtcm_d_tiles; + on_job.hmx_scales = factx.vtcm_hmx_scales_id; + on_job.n_row_tiles = n_row_tiles; + on_job.n_row_tiles_g_br = n_row_tiles_g_br; + on_job.DV = DV; + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_fa_o_norm_worker, &on_job)); + hmx_queue_pop(ctx->hmx_queue); + } else { + const size_t DV_tiles = (size_t) (DV / 32); + const __fp16 * restrict d_base = factx.vtcm_d_tiles; + const __fp16 * restrict op_base = o_tile_prev; + __fp16 * restrict oc_base = o_tile_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + __fp16 * o_out = oc_base + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS; + + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + Q6_mxmem_AR_after_hf(o_out, 0); + } + } + } + } + TIMER_STOP(o_norm); + + // ---- Store O block ---- + TIMER_START(o_store); + fa_phase_o_store(&factx, dst, o_tile_curr, q_start, kv_head, ib3, n_rows_g); + TIMER_STOP(o_store); + +#undef MASK_DMA_PUSH +#undef MASK_DMA_POP +#undef DMA_PREFETCH_KV + + } // end Q block loop + } // end KV head loop + } // end batch loop + + if (factx.use_pipeline) { + hmx_queue_suspend(ctx->hmx_queue); + } else { + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + } + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "hmx-fa: %lld us, q_load=%lld kv_dma=%lld k_interleave=%lld v_interleave=%lld", TIMER_US(total), + TIMER_US(q_load), TIMER_US(kv_dma), TIMER_US(k_interleave), TIMER_US(v_interleave)); + FARF(HIGH, " qk_dot=%lld softmax=%lld o_update=%lld o_norm=%lld o_store=%lld", TIMER_US(qk_dot), TIMER_US(softmax), + TIMER_US(o_update), TIMER_US(o_norm), TIMER_US(o_store)); +#endif + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 05e3c6c2b0f..2666a78a96a 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -28,6 +28,8 @@ #include "hmx-queue.h" #include "hmx-profile.h" +#include "vtcm-utils.h" + static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, }; @@ -43,40 +45,11 @@ static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, }; -// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile. -// word[i] = i*128 maps K-row-pair i to byte offset i*128 in the tile. -// Column offset (n*4) is added at runtime. Only entries 0..15 are used (masked by predicate). -static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = { - 0*128, 1*128, 2*128, 3*128, 4*128, 5*128, 6*128, 7*128, - 8*128, 9*128, 10*128, 11*128, 12*128, 13*128, 14*128, 15*128, - 16*128, 17*128, 18*128, 19*128, 20*128, 21*128, 22*128, 23*128, - 24*128, 25*128, 26*128, 27*128, 28*128, 29*128, 30*128, 31*128 -}; - // Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes #define HMX_X4X2_SCALES_PER_BLK 8 #define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL) #define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4) -static inline void swap_ptr(void **p1, void **p2) { - void *t = *p1; - *p1 = *p2; - *p2 = t; -} - -typedef struct { - uint8_t *dst; - const uint8_t *src; - dma_queue *dma; - size_t n_rows; - size_t src_stride; // DDR row stride (full row_stride) - size_t dst_stride; // VTCM sub-block row stride - size_t quant_off; // quant byte offset in each DDR row - size_t quant_width; // quant bytes to copy per row - size_t scale_off; // scale byte offset in each DDR row - size_t scale_width; // scale bytes to copy per row -} qweight_fetch_task_state_t; - // Compute the byte stride of one row in x4x2 format. // Numerically equals ggml_row_size(type, k) when k is 256-aligned, because // x4x2 packing has the same density as block_q4_0 / block_q8_0. @@ -202,46 +175,6 @@ static int hmx_compute_chunks(size_t vtcm_total, return 0; } -// forward declaration – defined after transfer_activation_chunk_fp32_to_fp16 -void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride); - -// Scatter row-major FP16 weight (already in VTCM scratch) directly into transposed [K][N] tiles. -// vtcm_src: [n_cols][k] row-major fp16 in VTCM scratch buffer -// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16 -static void interleave_fp16_weight_chunk_to_tiles(__fp16 *restrict vtcm_dst, - const __fp16 *restrict vtcm_src, - int n_cols, int k) { - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - assert(k % HMX_FP16_TILE_N_COLS == 0); - - const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; - const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - for (int r = 0; r < n_cols; r += 2) { - int ct = r / HMX_FP16_TILE_N_ROWS; // N-dimension tile index - int local_r = r % HMX_FP16_TILE_N_ROWS; // intra-tile row index - const bool next_row_valid = (r + 1) < n_cols; - - // Offset vectors for N-columns local_r and local_r+1, reused across K-tiles. - HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); - HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); - - for (int c = 0; c < k; c += HMX_FP16_TILE_N_COLS) { - int kt = c / HMX_FP16_TILE_N_COLS; - int tile_idx = ct * n_k_tiles + kt; - __fp16 *tile_base = vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS; - - HVX_Vector v0 = hvx_vmemu(vtcm_src + r * k + c); - HVX_Vector v1 = next_row_valid ? hvx_vmemu(vtcm_src + (r + 1) * k + c) : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off0, v0); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off1, v1); - } - } -} - // --- x4x2 format dequantizers --- // Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. @@ -303,8 +236,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( } // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. -static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx( - const int8_t *quants_32, const __fp16 *scale) { +static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { HVX_Vector vq = hvx_vmemu(quants_32); HVX_Vector v_scales = hvx_vec_splat_f16(*scale); HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); @@ -414,8 +346,8 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128 // maps to K-rows 2i and 2i+1. Column offset (n*4) added per row. - const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes) unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index @@ -658,12 +590,12 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; state.n_tot_tiles = n_tot_tiles; state.n_tiles_per_task = n_tiles_per_task; - state.dst = vtcm_dst; - state.src = (const uint8_t *)vtcm_src; - state.n_cols = n_cols; - state.k_block = k_block; - state.row_stride = row_stride; - state.weight_type = weight_type; + state.dst = vtcm_dst; + state.src = (const uint8_t *)vtcm_src; + state.n_cols = n_cols; + state.k_block = k_block; + state.row_stride = row_stride; + state.weight_type = weight_type; worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads); } @@ -733,7 +665,7 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, job->n_dot_tiles = n_dot_tiles; } -// --- End async HMX matmul job --- +// output : fp16 -> f32p static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); @@ -807,295 +739,397 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads); } -static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { - return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; -} +// activations : fp32 -> fp16 -static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { - return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; -} +static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) { + for (int r = 0; r < n_rows; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx -static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, - int dst_b2, int dst_b3) { - const int r2 = hmx_matmul_batch_r2(params); - const int r3 = hmx_matmul_batch_r3(params); - return (const __fp16 *) ((const uint8_t *) params->permuted_weight + - (size_t) (dst_b2 / r2) * params->src0_nb2 + - (size_t) (dst_b3 / r3) * params->src0_nb3); -} + const bool next_row_valid = (r + 1) < n_rows; -static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, - int dst_b2, int dst_b3) { - return (const float *) ((const uint8_t *) params->activation + - (size_t) dst_b2 * params->src1_nb2 + - (size_t) dst_b3 * params->src1_nb3); -} + const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); + const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero(); -static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, - int dst_b2, int dst_b3) { - return (float *) ((uint8_t *) params->dst + - (size_t) dst_b2 * params->dst_nb2 + - (size_t) dst_b3 * params->dst_nb3); -} + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); -static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params) { - int ret = 0; - for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { - for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { - ret = hmx_mat_mul_permuted_w16a32(ctx, - hmx_matmul_dst_batch_ptr(params, b2, b3), - hmx_matmul_activation_batch_ptr(params, b2, b3), - hmx_matmul_weight_batch_ptr(params, b2, b3), - params->m, params->k, params->n, - params->act_stride, params->weight_stride); + // compute output position + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; } } - return ret; } -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { - if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } - if (!params->m || !params->k || !params->n) { return -1; } - if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } - if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } - if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } - if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + int k_stride; +} activation_transfer_task_state_t; - if (!hex_is_aligned(params->dst, VLEN) || - !hex_is_aligned(params->activation, VLEN) || - !hex_is_aligned(params->permuted_weight, VLEN)) { - return -1; - } +static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; - const int group_size = hmx_matmul_batch_r2(params); + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + // one chunk: one row + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); - if (group_size <= 1) { - FARF(MEDIUM, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + __fp16 *dst = st->dst + chunk_idx * st->k_block; + const float *src = st->src + chunk_idx * st->k_stride; + transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); } +} - // Grouped path: reuse interleaved weight across all q_heads sharing a - // kv_head. Each q_head gets its own activation buffer in VTCM (so - // activation is loaded once per m_chunk and reused across all n_chunks), - // and each q_head is computed individually to avoid tile-major packing - // issues. m_chunk_n_rows is always a multiple of 32 (from - // hmx_compute_chunks), so per-head tile arrays don't overlap. - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = params->k * sizeof(__fp16); +static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { + assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); + assert(VLEN == 32 * sizeof(float)); - // When the activation has a large stride (e.g. permuted Q tensor with - // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. - // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather - // strided rows into a contiguous block before the F32->F16 conversion. - const bool use_dma_activation = (params->act_stride > params->k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, - /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, - /*per_mn=*/sizeof(__fp16), params->m, params->n, - /*m_block_cost=*/(size_t) params->n, - /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); - } + activation_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.src = src; + state.k_block = k_block; + state.k_stride = k_stride; - const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); +} - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; +// - if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { - FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); - } +#define FALLBACK_TO_STANDARD 1 - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 +// C += AB +static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, + const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, + int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); - FARF(MEDIUM, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, params->m, params->k, params->n, group_size, params->ne13, - m_chunk_n_rows, n_chunk_n_cols, - (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + Q6_bias_mxmem2_A((void *)col_scales); - TIMER_DEFINE(activation_load); - TIMER_DEFINE(weight_load); - TIMER_DEFINE(hmx_core); - TIMER_DEFINE(output_store); - TIMER_DEFINE(total); + const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t i = 0; i < n_row_tiles; ++i) { + const __fp16 *row_base = a + i * dot_tile_stride; + __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t j = 0; j < n_col_tiles; ++j) { + Q6_mxclracc_hf(); - TIMER_START(total); + const __fp16 *col_tiles = b + j * dot_tile_stride; + const __fp16 *row_tiles = row_base; + __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; + if (!zero_init) { + Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); + } - const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + for (int k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(accum_tile, 0); + } + } +} - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); +static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, + float *restrict out, const float *restrict x, const uint8_t *restrict w, + int m, int k, int n, int weight_type) { + // assume k % 32 == 0 && n % 32 == 0 + const size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } - for (int b3 = 0; b3 < params->ne13; ++b3) { - for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { - const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); + const size_t vtcm_budget = ctx->vtcm_size; - for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); + const size_t K_BLOCK_SIZE = 1024; - // Pre-load activations for all heads in the group (once per m_chunk). - // When the source is strided (permuted Q), use 2D DMA to gather - // contiguous rows into a VTCM scratch buffer first, then HVX - // converts from the contiguous VTCM buffer. This avoids L2 cache - // thrashing from HVX loads at large strides. - TIMER_START(activation_load); - for (int g = 0; g < group_size; ++g) { - const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; - __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) params->k * sizeof(float); - const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_act_g, - vtcm_f32_act, (int) n_rows, - params->k, params->k); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_act_g, - activation_chunk, (int) n_rows, - params->k, params->act_stride); - } - } - TIMER_STOP(activation_load); + // Fallback: if k doesn't need K-blocking, out-stationary has no advantage + const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; + if (k_iters_check <= 1) { + FARF(HIGH, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); + return FALLBACK_TO_STANDARD; + } - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; + // Dynamic M,N search via hmx_compute_chunks + const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); + const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) + + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) + const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) + + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) + const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) - { - const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } + // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; + // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin + const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; + const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment - for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); + size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; + // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. + // From profiling: wt_dequant per element ≈ 1.5× activation load per element. + // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). + // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). + const size_t m_block_cost = (size_t) n * 3; + const size_t n_block_cost = (size_t) m * 2; + if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE, + &N_BLOCK_SIZE, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); + // Compute precise buffer sizes from searched M,N and fixed K + const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); + const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < (size_t) params->n) { - const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; + const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; + if (total_vtcm > vtcm_budget) { + FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, + vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); + return -1; + } - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); - } + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); + uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); + uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); + __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k); - swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); + FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, + M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - // Reuse the interleaved weight for every q_head in this GQA group - for (int g = 0; g < group_size; ++g) { - TIMER_START(hmx_core); - { - const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, - params->k / 32); - } - TIMER_STOP(hmx_core); + // initialize eye tile (32x32 identity matrix) + { + HVX_Vector v; + v = Q6_V_vzero(); + v = Q6_Vw_vinsert_VwR(v, 0x3c000000); + v = Q6_V_vror_VR(v, VLEN - 4); + v = Q6_Vw_vinsert_VwR(v, 0x00003c00); + for (int i = 0; i < 16; ++i) { + ((HVX_Vector *) vtcm_eye_tile)[i] = v; + v = Q6_V_vror_VR(v, VLEN - 8); + } + } + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - TIMER_START(output_store); - { - float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; - transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); - } - TIMER_STOP(output_store); - } + TIMER_DEFINE(fetch); + TIMER_DEFINE(act_load); + TIMER_DEFINE(wt_dequant); + TIMER_DEFINE(core); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { + size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); + for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { + size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); + + const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); + const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); + + for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { + const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); + + TIMER_START(fetch); + // fetch activation block into VTCM + { + const float *activation_block = x + mr * k + kk; + + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_scratch1, activation_block), + k_blk_sz * sizeof(float), + k * sizeof(float), + k_blk_sz * sizeof(float), + m_blk_sz); + } + + // fetch weight block into VTCM (x4x2 sub-block: quants + scales) + const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); + { + const int blk_start = kk / QK_Q4_0x4x2; + const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; + const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); + const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; + uint8_t *dst = vtcm_scratch0; + const uint8_t *src = w + nc * row_stride; + const size_t n_rows = n_blk_sz; + const size_t src_stride = row_stride; + const size_t dst_stride = sub_row_stride; + const size_t quant_off = (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); + const size_t quant_width = (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); + const size_t scale_off = full_qrow + blk_start * scale_blk_size; + const size_t scale_width = nb_sub * scale_blk_size; + + // 2D DMA: quants sub-range + dma_queue_push(ctx->dma[0], dma_make_ptr(dst, src + quant_off), dst_stride, src_stride, quant_width, n_rows); + // 2D DMA: scales sub-range + dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width, src + scale_off), dst_stride, src_stride, scale_width, n_rows); + } + TIMER_STOP(fetch); + + TIMER_START(act_load); + // load activation block + { + dma_queue_pop(ctx->dma[0]); // wait for act DNA + transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); + } + TIMER_STOP(act_load); + + TIMER_START(wt_dequant); + // dequantize weight block + { + dma_queue_pop(ctx->dma[0]); + dma_queue_pop(ctx->dma[0]); + // vtcm_scratch0 is used to store the qweight chunk + // worker_pool_run_func already returned, so fetch is done + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, + n_blk_sz, k_blk_sz, sub_row_stride, weight_type); + } + TIMER_STOP(wt_dequant); + + // core mma + TIMER_START(core); + { + core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, + n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); } + TIMER_STOP(core); + } + + // store output block + { + float *output_block = out + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); } } } HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - TIMER_STOP(total); - #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total), - params->m, params->k, params->n, group_size); - FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", - TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); + FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", + TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); #endif - - return 0; + return 0; } -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, - const __fp16 *restrict permuted_weight, int m, int k, int n, - int act_stride, int weight_stride) { +int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const uint8_t *restrict permuted_weight, int m, int k, int n, + int weight_type) { if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - if (act_stride < k || weight_stride < k) { return -1; } if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; + return -1; + } + + // for large m, k (e.g. prefill FFN Down), use out-stationary version + if (m >= 128 && k > n && n > 1024) { + int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); + if (rc != FALLBACK_TO_STANDARD) { + return rc; // 0 success, -1 error + } + FARF(HIGH, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); + // fall through to standard path + } + + size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; } + FARF(HIGH, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); + // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + const size_t vec_dot_size = k * sizeof(__fp16); - // DMA-based activation gather for strided tensors (see batched path comment). - const bool use_dma_activation = (act_stride > k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; + // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. + // Only pays off when the chunker yields >=2 n-chunks, so the main loop can + // overlap HMX (C) with HVX (B/D); with a single n-chunk the extra VTCM for + // double-buffered output and the worker-dispatch overhead are pure loss. + // Try pipeline costs first; fall back to sequential if the layout collapses + // to one n-chunk. m >= 128 floor keeps HMX utilization reasonable. + const size_t pipe_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) + const size_t pipe_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) + const size_t seq_per_n = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) + const size_t seq_per_mn = sizeof(__fp16); // O x 1 size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, - /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, // W + S0 + S1 - /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch - /*per_mn=*/sizeof(__fp16), // O - m, n, - /*m_block_cost=*/(size_t) n, - /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; + bool use_pipeline = false; + + if (m >= 128) { + size_t mc = 0, nc = 0, used = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, m, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 && + hmx_ceil_div((size_t) n, nc) >= 2) { + m_chunk_n_rows = mc; + n_chunk_n_cols = nc; + vtcm_used = used; + use_pipeline = true; + } } - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + if (!use_pipeline) { + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, m, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } + } + + // Compute precise buffer sizes per execution path + const size_t weight_area_size = hex_align_up( + n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE); const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + const size_t output_area_size = hex_align_up( + m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + + size_t scratch0_size, scratch1_size, scratch2_size; + if (use_pipeline) { + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 + scratch1_size = scratch0_size; // dequant buf 1 + scratch2_size = output_area_size; // output buf 1 + } else { + scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0 + scratch1_size = scratch0_size; // x4x2 DMA buf 1 + scratch2_size = 0; // unused + } - // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); + void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); @@ -1104,8 +1138,9 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(MEDIUM, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, + FARF(HIGH, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, m, k, n, weight_type, use_pipeline, + m_chunk_n_rows, n_chunk_n_cols, (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); TIMER_DEFINE(activation_load); @@ -1116,214 +1151,9 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co TIMER_DEFINE(total); TIMER_START(total); - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * act_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) k * sizeof(float); - const size_t stride_bytes = (size_t) act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_activation, - vtcm_f32_act, n_rows, k, k); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_activation, - activation_chunk, n_rows, k, act_stride); - } - } - TIMER_STOP(activation_load); - - const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); - - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; - - // issue async DMA for the first weight chunk - // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. - // The source rows can be strided (e.g. KV-cache K after ggml_permute). - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } - - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready - - // issue async DMA for the next weight chunk (double buffering) - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); - } - - // interleave row-major fp16 from scratch into tile-major in vtcm_weight - interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *)buf_curr, n_cols, k); - - swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); - - TIMER_START(hmx_core); - { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); - } - TIMER_STOP(hmx_core); - - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); - } - - } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - - TIMER_STOP(total); - -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); - FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", - TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); - { - size_t weight_size = (size_t)k * n * sizeof(__fp16); - float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); - FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); - } -#endif - - return 0; -} - -int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m, - int k, int n, int w_type); - -#define FALLBACK_TO_STANDARD 1 - -int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, - const uint8_t *restrict permuted_weight, int m, int k, int n, - int weight_type) { - if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - if (k % 32 != 0 || n % 32 != 0) { return -1; } - - if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; - } - - // for large m, k (e.g. prefill FFN Down), use out-stationary version - if (m >= 128 && k > n && n > 1024) { - int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); - if (rc != FALLBACK_TO_STANDARD) { - return rc; // 0 success, -1 error - } - FARF(MEDIUM, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); - // fall through to standard path - } - - size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - FARF(MEDIUM, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); - - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); - const bool use_pipeline = (m >= 128) && (k <= n); - - // Select cost parameters based on execution path - size_t per_n_cost, per_mn_cost; - if (use_pipeline) { - per_n_cost = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - per_mn_cost = 2 * sizeof(__fp16); // O x 2 (output double buffer) - } else { - per_n_cost = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) - per_mn_cost = sizeof(__fp16); // O x 1 - } - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // Quantized weight: dequant ~1.5x more expensive per element than activation load. - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost, m, n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d pipe=%d budget=%zu)", - __func__, m, k, n, use_pipeline, vtcm_budget); - return -1; - } - - // Compute precise buffer sizes per execution path - const size_t weight_area_size = hex_align_up( - n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up( - m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - - size_t scratch0_size, scratch1_size, scratch2_size; - if (use_pipeline) { - scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = scratch0_size; // dequant buf 1 - scratch2_size = output_area_size; // output buf 1 - } else { - scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0 - scratch1_size = scratch0_size; // x4x2 DMA buf 1 - scratch2_size = 0; // unused - } - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); - void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - return -1; - } - - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, weight_type, use_pipeline, - m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - - TIMER_DEFINE(activation_load); - TIMER_DEFINE(weight_load); - TIMER_DEFINE(hmx_core); - TIMER_DEFINE(output_store); - - TIMER_DEFINE(total); - TIMER_START(total); - - FARF(MEDIUM, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", - use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + FARF(HIGH, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", + use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); if (!use_pipeline) { HAP_compute_res_hmx_lock(ctx->vtcm_rctx); @@ -1368,7 +1198,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight. dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type); - swap_ptr(&buf_curr, &buf_next); + hex_swap_ptr(&buf_curr, &buf_next); } TIMER_STOP(weight_load); @@ -1511,300 +1341,417 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds return 0; } -// C += AB -void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, - int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(n_dot_tiles > 0); +// - Q6_bias_mxmem2_A((void *)col_scales); +static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { + return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; +} - const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; - for (size_t i = 0; i < n_row_tiles; ++i) { - const __fp16 *row_base = a + i * dot_tile_stride; - __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; - for (size_t j = 0; j < n_col_tiles; ++j) { - Q6_mxclracc_hf(); +static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { + return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; +} - const __fp16 *col_tiles = b + j * dot_tile_stride; - const __fp16 *row_tiles = row_base; - __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; - if (!zero_init) { - Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); - } +static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + const int r2 = hmx_matmul_batch_r2(params); + const int r3 = hmx_matmul_batch_r3(params); + return (const __fp16 *) ((const uint8_t *) params->permuted_weight + + (size_t) (dst_b2 / r2) * params->src0_nb2 + + (size_t) (dst_b3 / r3) * params->src0_nb3); +} - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; - } - Q6_mxmem_AR_after_hf(accum_tile, 0); - } - } +static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (const float *) ((const uint8_t *) params->activation + + (size_t) dst_b2 * params->src1_nb2 + + (size_t) dst_b3 * params->src1_nb3); } -static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, - int k_block, int k_stride) { - for (int r = 0; r < n_rows; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx +static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (float *) ((uint8_t *) params->dst + + (size_t) dst_b2 * params->dst_nb2 + + (size_t) dst_b3 * params->dst_nb3); +} - const bool next_row_valid = (r + 1) < n_rows; +static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, + const hmx_matmul_w16a32_batched_params_t *params) { + int ret = 0; + for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { + for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { + ret = hmx_mat_mul_permuted_w16a32(ctx, + hmx_matmul_dst_batch_ptr(params, b2, b3), + hmx_matmul_activation_batch_ptr(params, b2, b3), + hmx_matmul_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride); + } + } + return ret; +} - const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); - const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = *pv_in0++; - HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero(); +int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { + if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } + if (!params->m || !params->k || !params->n) { return -1; } + if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } + if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } + if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } + if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + if (!hex_is_aligned(params->dst, VLEN) || + !hex_is_aligned(params->activation, VLEN) || + !hex_is_aligned(params->permuted_weight, VLEN)) { + return -1; + } - // compute output position - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + const int group_size = hmx_matmul_batch_r2(params); - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } + if (group_size <= 1) { + FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); } -} -typedef struct { - __fp16 *dst; - const float *src; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int k_block; - int k_stride; -} activation_transfer_task_state_t; + // Grouped path: reuse interleaved weight across all q_heads sharing a + // kv_head. Each q_head gets its own activation buffer in VTCM (so + // activation is loaded once per m_chunk and reused across all n_chunks), + // and each q_head is computed individually to avoid tile-major packing + // issues. m_chunk_n_rows is always a multiple of 32 (from + // hmx_compute_chunks), so per-head tile arrays don't overlap. + const size_t vtcm_budget = ctx->vtcm_size; + const size_t vec_dot_size = params->k * sizeof(__fp16); -static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { - activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; + // When the activation has a large stride (e.g. permuted Q tensor with + // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. + // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather + // strided rows into a contiguous block before the F32->F16 conversion. + const bool use_dma_activation = (params->act_stride > params->k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; - for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { - // one chunk: one row - int chunk_idx = task_id * st->n_chunks_per_task; - size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, + /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, + /*per_mn=*/sizeof(__fp16), params->m, params->n, + /*m_block_cost=*/(size_t) params->n, + /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + } - __fp16 *dst = st->dst + chunk_idx * st->k_block; - const float *src = st->src + chunk_idx * st->k_stride; - transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); + const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + + if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { + FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); } -} -void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { - assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); - assert(VLEN == 32 * sizeof(float)); + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address + FARF(HIGH, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, params->m, params->k, params->n, group_size, params->ne13, + m_chunk_n_rows, n_chunk_n_cols, + (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - activation_transfer_task_state_t state; - state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; - state.n_tot_chunks = n_tot_chunks; - state.n_chunks_per_task = n_chunks_per_task; - state.dst = dst; - state.src = src; - state.k_block = k_block; - state.k_stride = k_stride; + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + TIMER_DEFINE(total); - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); -} + TIMER_START(total); -int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, - int m, int k, int n, int weight_type) { - // assume k % 32 == 0 && n % 32 == 0 - const size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; + const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (int b3 = 0; b3 < params->ne13; ++b3) { + for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { + const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); + + for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); + + // Pre-load activations for all heads in the group (once per m_chunk). + // When the source is strided (permuted Q), use 2D DMA to gather + // contiguous rows into a VTCM scratch buffer first, then HVX + // converts from the contiguous VTCM buffer. This avoids L2 cache + // thrashing from HVX loads at large strides. + TIMER_START(activation_load); + for (int g = 0; g < group_size; ++g) { + const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; + __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) params->k * sizeof(float); + const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + vtcm_f32_act, (int) n_rows, + params->k, params->k); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride); + } + } + TIMER_STOP(activation_load); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + { + const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } + + for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); + + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < (size_t) params->n) { + const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; + + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k, params->k, + 0, n_cols); + hex_swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + // Reuse the interleaved weight for every q_head in this GQA group + for (int g = 0; g < group_size; ++g) { + TIMER_START(hmx_core); + { + const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, + params->k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); + } + TIMER_STOP(output_store); + } + } + } + } } - const size_t vtcm_budget = ctx->vtcm_size; + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - const size_t K_BLOCK_SIZE = 1024; + TIMER_STOP(total); - // Fallback: if k doesn't need K-blocking, out-stationary has no advantage - const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; - if (k_iters_check <= 1) { - FARF(MEDIUM, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); - return FALLBACK_TO_STANDARD; +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total), + params->m, params->k, params->n, group_size); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); +#endif + + return 0; +} + +// + +int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const __fp16 *restrict permuted_weight, int m, int k, int n, + int act_stride, int weight_stride) { + if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } + if (act_stride < k || weight_stride < k) { return -1; } + if (k % 32 != 0 || n % 32 != 0) { return -1; } + + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; } - // Dynamic M,N search via hmx_compute_chunks - const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); - const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) - + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) - const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) - + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) - const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) - // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; - // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin - const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; - const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment + // --- Dynamic VTCM layout --- + const size_t vtcm_budget = ctx->vtcm_size; + const size_t vec_dot_size = k * sizeof(__fp16); - size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; - // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. - // From profiling: wt_dequant per element ≈ 1.5× activation load per element. - // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). - // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). - const size_t m_block_cost = (size_t) n * 3; - const size_t n_block_cost = (size_t) m * 2; - if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE, - &N_BLOCK_SIZE, &vtcm_used) != 0) { + // DMA-based activation gather for strided tensors (see batched path comment). + const bool use_dma_activation = (act_stride > k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. + if (hmx_compute_chunks(vtcm_budget, + /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, // W + S0 + S1 + /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch + /*per_mn=*/sizeof(__fp16), // O + m, n, + /*m_block_cost=*/(size_t) n, + /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); return -1; } - // Compute precise buffer sizes from searched M,N and fixed K - const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); - const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; - const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; - if (total_vtcm > vtcm_budget) { - FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, - vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); + // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { + FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); return -1; } - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); - uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); - uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); - __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, - M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + FARF(HIGH, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - // initialize eye tile (32x32 identity matrix) - { - HVX_Vector v; - v = Q6_V_vzero(); - v = Q6_Vw_vinsert_VwR(v, 0x3c000000); - v = Q6_V_vror_VR(v, VLEN - 4); - v = Q6_Vw_vinsert_VwR(v, 0x00003c00); - for (int i = 0; i < 16; ++i) { - ((HVX_Vector *) vtcm_eye_tile)[i] = v; - v = Q6_V_vror_VR(v, VLEN - 8); - } - } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); - TIMER_DEFINE(fetch); - TIMER_DEFINE(act_load); - TIMER_DEFINE(wt_dequant); - TIMER_DEFINE(core); + TIMER_DEFINE(total); + TIMER_START(total); HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { - size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); - for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { - size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + // transfer activation matrix chunk into VTCM + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); + TIMER_START(activation_load); + { + const float *activation_chunk = activation + mr * act_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) k * sizeof(float); + const size_t stride_bytes = (size_t) act_stride * sizeof(float); + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_activation, + vtcm_f32_act, n_rows, k, k); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_activation, + activation_chunk, n_rows, k, act_stride); + } + } + TIMER_STOP(activation_load); - for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { - const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); + const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); - TIMER_START(fetch); - // fetch activation block into VTCM - { - const float *activation_block = x + mr * k + kk; + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_scratch1, activation_block), - k_blk_sz * sizeof(float), - k * sizeof(float), - k_blk_sz * sizeof(float), - m_blk_sz); - } + // issue async DMA for the first weight chunk + // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. + // The source rows can be strided (e.g. KV-cache K after ggml_permute). + { + const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - // fetch weight block into VTCM (x4x2 sub-block: quants + scales) - const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); - { - qweight_fetch_task_state_t s; - - const int blk_start = kk / QK_Q4_0x4x2; - const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); - const int scale_blk_size = - (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; - - s.dst = vtcm_scratch0; - s.src = w + nc * row_stride; - s.n_rows = n_blk_sz; - s.src_stride = row_stride; - s.dst_stride = sub_row_stride; - s.quant_off = - (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); - s.quant_width = - (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); - s.scale_off = full_qrow + blk_start * scale_blk_size; - s.scale_width = nb_sub * scale_blk_size; + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } - // 2D DMA: quants sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off), - s.dst_stride, s.src_stride, s.quant_width, s.n_rows); - // 2D DMA: scales sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off), - s.dst_stride, s.src_stride, s.scale_width, s.n_rows); - } - TIMER_STOP(fetch); + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - TIMER_START(act_load); - // load activation block - { - dma_queue_pop(ctx->dma[0]); // wait for act DNA - transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); - } - TIMER_STOP(act_load); + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready - TIMER_START(wt_dequant); - // dequantize weight block - { - dma_queue_pop(ctx->dma[0]); - dma_queue_pop(ctx->dma[0]); - // vtcm_scratch0 is used to store the qweight chunk - // worker_pool_run_func already returned, so fetch is done - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, - n_blk_sz, k_blk_sz, sub_row_stride, weight_type); - } - TIMER_STOP(wt_dequant); + // issue async DMA for the next weight chunk (double buffering) + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < n) { + const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; - // core mma - TIMER_START(core); - { - core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, - n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); } - TIMER_STOP(core); + + // interleave row-major fp16 from scratch into tile-major in vtcm_weight + hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, k, k, 0, n_cols); + + hex_swap_ptr(&buf_curr, &buf_next); } + TIMER_STOP(weight_load); - // store output block + TIMER_START(hmx_core); { - float *output_block = out + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); } + TIMER_STOP(output_store); } + } HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + TIMER_STOP(total); + #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", - TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); + { + size_t weight_size = (size_t)k * n * sizeof(__fp16); + float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); + FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + } #endif + return 0; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index fb95d36f5a9..1c78ffadd1c 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -61,6 +61,9 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, int m, int k, int n, int weight_type); +// HMX flash attention +int hmx_flash_attn_ext(struct htp_ops_context * octx); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-hexagon/htp/hmx-utils.h b/ggml/src/ggml-hexagon/htp/hmx-utils.h index af04619cebb..68f174d6937 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hmx-utils.h @@ -4,6 +4,9 @@ #ifndef HMX_UTILS_H #define HMX_UTILS_H +#include "hvx-base.h" + +#include #include #include @@ -12,21 +15,188 @@ #define HMX_FP16_TILE_N_ELMS 1024 #define HMX_FP16_TILE_SIZE 2048 -#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline)) - // Initialise aligned 256-byte area with scale vector + zero padding. -static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) { - HVX_Vector *pv = (HVX_Vector *)out_scales; - *pv++ = v_scale; - *pv = Q6_V_vzero(); +static inline void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) { + volatile HVX_Vector *pv = (HVX_Vector *) out_scales; + pv[0] = v_scale; + pv[1] = Q6_V_vzero(); +} + +// --- Shared scatter offsets and interleave helper --- + +// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile. +// word[i] = i*128 maps K-row-pair i to byte offset i*128. +// Column offset (n*4) is added at runtime. Entries 0..15 cover one tile (region 2047); +// entries 16..31 cover the next adjacent tile (region 4095) — pick region size at the +// call site to scatter into one tile (masked) or two contiguous tiles (unmasked). +static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = { + 0 * 128, 1 * 128, 2 * 128, 3 * 128, 4 * 128, 5 * 128, 6 * 128, 7 * 128, 8 * 128, 9 * 128, 10 * 128, + 11 * 128, 12 * 128, 13 * 128, 14 * 128, 15 * 128, 16 * 128, 17 * 128, 18 * 128, 19 * 128, 20 * 128, 21 * 128, + 22 * 128, 23 * 128, 24 * 128, 25 * 128, 26 * 128, 27 * 128, 28 * 128, 29 * 128, 30 * 128, 31 * 128, +}; + +// Scatter row-major FP16 data (in VTCM scratch) into transposed [K][N] tiles. +// vtcm_src: [n_cols][src_stride] row-major fp16 (only first k elements per row are used) +// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16 +// Processes rows [start_row, end_row) for multi-thread slicing. +// Full range: start_row=0, end_row=n_cols. +static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, + const __fp16 * restrict vtcm_src, + int n_cols, + int k, + int src_stride, + int start_row, + int end_row) { + assert(k % HMX_FP16_TILE_N_COLS == 0); + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + // Each hvx_vmemu load brings 64 fp16 = 128 bytes covering 2 adjacent K-tiles. + // When n_k_tiles is even, scatter into 2 K-tiles per call (region 4095, no mask) + // using the upper half of hmx_transpose_scatter_offsets. Tail one K-tile (when + // n_k_tiles is odd) falls back to single-tile masked scatter. + const bool pair_scatter = (n_k_tiles & 1) == 0; + const size_t pair_region = (size_t) (2 * HMX_FP16_TILE_SIZE - 1); + const size_t single_region = (size_t) (HMX_FP16_TILE_SIZE - 1); + __builtin_assume(k > 0); + __builtin_assume(end_row > start_row); + + if (pair_scatter) { + // Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter. + const int c_step = 2 * HMX_FP16_TILE_N_COLS; + const size_t c_byte_step = (size_t) c_step * sizeof(__fp16); + const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS; + const int n_c_iters = k / c_step; + + for (int r = start_row; r < end_row; r += 2) { + const int ct = r / HMX_FP16_TILE_N_ROWS; + const int local_r = r % HMX_FP16_TILE_N_ROWS; + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols; + const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); + const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); + + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + + if (p1) { + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + HVX_Vector v1 = hvx_vmemu(p1); + p1 += c_byte_step; + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1); + tile_base += dst_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero); + tile_base += dst_step; + } + } + } + } else { + // Fallback: scatter one K-tile per call (region 2047, masked). + const int c_step = HMX_FP16_TILE_N_COLS; + const size_t c_byte_step = (size_t) c_step * sizeof(__fp16); + const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS; + const int n_c_iters = k / c_step; + + for (int r = start_row; r < end_row; r += 2) { + const int ct = r / HMX_FP16_TILE_N_ROWS; + const int local_r = r % HMX_FP16_TILE_N_ROWS; + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols; + const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); + const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); + + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + + if (p1) { + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + HVX_Vector v1 = hvx_vmemu(p1); + p1 += c_byte_step; + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1); + tile_base += dst_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero); + tile_base += dst_step; + } + } + } + } } -// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) --- +// Interleave row-major FP16 data into column-major tile format. +// Input: [n_rows, head_dim] row-major. Output: tile[dim_tile][row_tile]. +// Processes rows [start_row, end_row) for multi-thread slicing. +// Full range: start_row=0, end_row=n_rows. +static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out, + const __fp16 * restrict src, + int n_rows, + int head_dim, + int src_stride, + int n_row_tiles, + int start_row, + int end_row) { + __builtin_assume(head_dim > 0); + const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS; + + for (int r = start_row; r < end_row; r += 2) { + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows; + + const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride); + const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL; + + // Row-pair invariants hoisted out of the c loop. + const int r0 = r / HMX_FP16_TILE_N_ROWS; + const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2; + + // tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0). + // Each c step (+= 64) advances both by 2 dim-tiles worth of fp16. + __fp16 * tb0 = tiles_out + (size_t) r0 * HMX_FP16_TILE_N_ELMS; + __fp16 * tb1 = tb0 + tile_stride_elms; + const size_t tb_step = 2 * tile_stride_elms; -static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { - uint8_t *p = *vtcm_ptr; - *vtcm_ptr += size; - return p; + if (pv_in1) { + for (int c = 0; c < head_dim; c += 64) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = *pv_in1++; + HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2); + ((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp); + ((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp); + tb0 += tb_step; + tb1 += tb_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int c = 0; c < head_dim; c += 64) { + HVX_Vector v0 = *pv_in0++; + HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2); + ((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp); + ((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp); + tb0 += tb_step; + tb1 += tb_step; + } + } + } } #endif // HMX_UTILS_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index d0926dedd28..f6cb02951d0 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -77,6 +77,12 @@ static inline int32_t hvx_vec_get_i32(HVX_Vector v) { return x; } +static inline _Float16 hvx_vec_get_f16(HVX_Vector v) { + _Float16 __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 2, v); + return x; +} + static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) { // abs by clearing the fp16 sign bit HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h index 851482e01b2..a3e33c3b3af 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-copy.h +++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h @@ -7,7 +7,8 @@ #include "hvx-base.h" -#define hvx_splat_loop_body(dst_type, vec_store) \ +#define hvx_splat_pragma(x) _Pragma(#x) +#define hvx_splat_loop_body(dst_type, vec_store, unroll_cnt) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ \ @@ -16,7 +17,7 @@ \ uint32_t i = 0; \ \ - _Pragma("unroll(4)") \ + hvx_splat_pragma(unroll(unroll_cnt)) \ for (; i < nvec; i++) { \ vdst[i] = src; \ } \ @@ -25,31 +26,47 @@ } \ } while(0) -static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { +static inline void hvx_splat_a(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { assert((unsigned long) dst % 128 == 0); - hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a); + hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a, 4); } -static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { - hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u); +static inline void hvx_splat_u(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { + hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u, 4); } -static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f32_a(void * restrict dst, float v, uint32_t n) { hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float)); } -static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f32_u(void * restrict dst, float v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float)); } -static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) { +static inline void hvx_splat_f16_a(void * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } -static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) { +static inline void hvx_splat_f16_u(void * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } +static inline void hvx_splat_u16_a(void * restrict dst, uint16_t v, uint32_t n) { + hvx_splat_a(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t)); +} + +static inline void hvx_splat_u16_u(void * restrict dst, uint16_t v, uint32_t n) { + hvx_splat_u(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t)); +} + +static inline void hvx_splat_u8_a(void * restrict dst, uint8_t v, uint32_t n) { + hvx_splat_a(dst, Q6_Vb_vsplat_R(v), n, 1); +} + +static inline void hvx_splat_u8_u(void * restrict dst, uint8_t v, uint32_t n) { + hvx_splat_u(dst, Q6_Vb_vsplat_R(v), n, 1); +} + #define hvx_copy_loop_body(dst_type, src_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ diff --git a/ggml/src/ggml-hexagon/htp/vtcm-utils.h b/ggml/src/ggml-hexagon/htp/vtcm-utils.h new file mode 100644 index 00000000000..b129fb74e31 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/vtcm-utils.h @@ -0,0 +1,16 @@ +#ifndef VTCM_UTILS_H +#define VTCM_UTILS_H + +#include "hex-utils.h" + +#include +#include +#include + +static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { + uint8_t *p = *vtcm_ptr; + *vtcm_ptr += size; + return p; +} + +#endif // VTCM_UTILS_H From 28f8534532f5be51fa6bc0a27c30e0dbecc9769f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 08:45:46 +0300 Subject: [PATCH 238/249] ggml : bump version to 0.10.2 (ggml/1474) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index f7b6f1f334f..c97f681988b 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 10) -set(GGML_VERSION_PATCH 1) +set(GGML_VERSION_PATCH 2) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From a5a8496d31ef1690ff2addc65b555916c3cf8895 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 08:49:06 +0300 Subject: [PATCH 239/249] ggml : remove obsoloete wgsl templates (ggml/0) --- .../ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl | 107 ------ .../ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl | 323 ---------------- .../ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl | 295 --------------- .../wgsl-shaders/soft_max.tmpl.wgsl | 345 ------------------ 4 files changed, 1070 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl deleted file mode 100644 index b5e93b812fd..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +++ /dev/null @@ -1,107 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "f32" - } - }, - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "i32" - } - }, - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "f16" - } - }, - { - "REPLS": { - "SRC_TYPE": "f16", - "DST_TYPE": "f16" - } - }, - { - "REPLS": { - "SRC_TYPE": "f16", - "DST_TYPE": "f32" - } - } -] - -#end(VARIANTS) - -#define(SHADER) -enable f16; - -@group(0) @binding(0) -var src: array<{{SRC_TYPE}}>; - -@group(0) @binding(1) -var dst: array<{{DST_TYPE}}>; - -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shapes - src_ne0: u32, - src_ne1: u32, - src_ne2: u32, - - dst_ne0: u32, - dst_ne1: u32, - dst_ne2: u32 -}; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); - i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); - let i2 = i / (params.src_ne1 * params.src_ne0); - i = i % (params.src_ne1 * params.src_ne0); - let i1 = i / params.src_ne0; - let i0 = i % params.src_ne0; - - var j = gid.x; - let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - let j2 = j / (params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne1 * params.dst_ne0); - let j1 = j / params.dst_ne0; - let j0 = j % params.dst_ne0; - - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; - - let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + - j2 * params.stride_dst2 + j3 * params.stride_dst3; - - dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx])); -} -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl deleted file mode 100644 index 03fcd548689..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +++ /dev/null @@ -1,323 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "reglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "geglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "swiglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_oai_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "SWIGLU_OAI"] - }, - { - "SHADER_NAME": "swiglu_oai_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "SWIGLU_OAI"] - }, - { - "SHADER_NAME": "geglu_erf_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_quick_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU_QUICK"] - }, -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(REGLU) -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return max(a, 0) * b; -} -#enddecl(REGLU) - -#decl(GEGLU) -const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876; -const GELU_COEF_A: {{TYPE}} = 0.044715; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); - return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b; -} -#enddecl(GEGLU) - -#decl(SWIGLU) -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return a / (1.0 + exp(-a)) * b; -} -#enddecl(SWIGLU) - -#decl(SWIGLU_OAI) -fn op(a: f32, b: f32) -> f32 { - let xi = min(a, params.limit); - let gi = max(min(b, params.limit), -params.limit); - var out_glu = xi / (1.0 + exp(-xi * params.alpha)); - out_glu = out_glu * (1.0 + gi); - return out_glu; -} -#enddecl(SWIGLU_OAI) - -#decl(GEGLU_ERF) -const p_erf: {{TYPE}} = 0.3275911; -const a1_erf: {{TYPE}} = 0.254829592; -const a2_erf: {{TYPE}} = -0.284496736; -const a3_erf: {{TYPE}} = 1.421413741; -const a4_erf: {{TYPE}} = -1.453152027; -const a5_erf: {{TYPE}} = 1.061405429; -const SQRT_2_INV: {{TYPE}} = 0.7071067811865476; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - let a_div_sqr2 = a * SQRT_2_INV; - let sign_x = sign(a_div_sqr2); - let x = abs(a_div_sqr2); - let t = 1.0 / (1.0 + p_erf * x); - let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); - let erf_approx = sign_x * y; - return 0.5 * a * (1.0 + erf_approx) * b; -} -#enddecl(GEGLU_ERF) - -#decl(GEGLU_QUICK) -const GELU_QUICK_COEF: {{TYPE}} = -1.702; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; -} -#enddecl(GEGLU_QUICK) - -#decl(NO_SPLIT) -@group(0) @binding(1) -var dst: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -fn a_value(base: u32) -> {{TYPE}} { - let offset: u32 = select(0, params.ne0, params.swapped != 0); - return src0[base + offset]; -} - -fn b_value(base: u32) -> {{TYPE}} { - let offset: u32 = select(params.ne0, 0, params.swapped != 0); - return src0[base + offset]; -} -#enddecl(NO_SPLIT) - -#decl(SPLIT) -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -fn a_value(base: u32) -> {{TYPE}} { - return src0[base]; -} - -fn b_value(base: u32) -> {{TYPE}} { - return src1[base]; -} -#enddecl(SPLIT) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_src11: u32, - stride_src12: u32, - stride_src13: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // shape of dst - ne: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - swapped: u32, - alpha: f32, - limit: f32, -} - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; - let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; - let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; - - dst[i_dst] = op(a_value(i_a), b_value(i_b)); -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl deleted file mode 100644 index 84dc8dbff61..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +++ /dev/null @@ -1,295 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f32_inplace", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] - }, - { - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f16_inplace", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] - }, - { - "SHADER_SUFFIX": "f32_ff", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f32_ff_inplace", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] - }, - { - "SHADER_SUFFIX": "f16_ff", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f16_ff_inplace", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(ROTATE) -fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { - dst[i_dst0] = {{TYPE}}(out0); - dst[i_dst1] = {{TYPE}}(out1); -} -#enddecl(ROTATE) - -#decl(ROTATE_INPLACE) -fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { - src0[i_dst0] = {{TYPE}}(out0); - src0[i_dst1] = {{TYPE}}(out1); -} -#enddecl(ROTATE_INPLACE) - -#decl(NO_FF_FUNC) -fn freq_factor(i: u32) -> f32 { - return 1.0f; -} -#enddecl(NO_FF_FUNC) - -#decl(FF_FUNC) -fn freq_factor(i: u32) -> f32 { - return src2[params.offset_src2 + i/2]; -} -#enddecl(FF_FUNC) - -#decl(NO_FF_BINDINGS) - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -#enddecl(NO_FF_BINDINGS) - -#decl(NO_FF_BINDINGS_INPLACE) - -@group(0) @binding(2) -var params: Params; - -#enddecl(NO_FF_BINDINGS_INPLACE) - -#decl(FF_BINDINGS) - -@group(0) @binding(2) -var src2: array; - -@group(0) @binding(3) -var dst: array<{{TYPE}}>; - -@group(0) @binding(4) -var params: Params; - -#enddecl(FF_BINDINGS) - -#decl(FF_BINDINGS_INPLACE) - -@group(0) @binding(2) -var src2: array; - -@group(0) @binding(3) -var params: Params; - -#enddecl(FF_BINDINGS_INPLACE) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_src2: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - n_threads: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - n_dims: u32, - mode: u32, - theta_scale: f32, - attn_factor: f32, - freq_scale: f32, - ext_factor: f32, - corr_dim0: f32, - corr_dim1: f32, - sections0: u32, - sections1: u32, - sections2: u32, - sections3: u32 -}; - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array; - -DECLS - -fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { - let y = (f32(i / 2) - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -// returns vector of (cos_theta, sin_theta) -// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row -fn rope_yarn(theta_extrap: f32, i: u32) -> vec2 { - var mscale = params.attn_factor; - var theta = params.freq_scale * theta_extrap; - if (params.ext_factor != 0.0f) { - let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor; - theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix; - mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale); - } - return vec2(cos(theta) * mscale, sin(theta) * mscale); -} - -fn pair_base(i0: u32, div_2: bool) -> u32 { - if (div_2) { - return i0 / 2; - } else { - return i0; - } -} - -fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 { - if (is_vision) { - return params.n_dims; - } else if (is_neox || is_mrope) { - return params.n_dims / 2; - } else { - return 1; - } -} - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - // two elements per thread - if (gid.x >= params.n_threads) { - return; - } - - let is_neox = bool(params.mode & 2); - let is_mrope = bool(params.mode & 8); - let is_imrope = params.mode == 40; - let is_vision = params.mode == 24; - - var i = gid.x * 2; // start index for this thread - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; - let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - - if (i0 >= params.n_dims && !is_vision) { - let i_src = i_src_row + i0; - let i_dst = i_dst_row + i0; - rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1])); - return; - } - - var theta_base_mult: u32 = 0; - var theta_scale_pwr: u32 = i0 / 2; - if (is_mrope) { - let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3; - let sec_w = params.sections1 + params.sections0; - let sec_e = params.sections2 + sec_w; - let sector = (i0 / 2) % sect_dims; - if (is_imrope) { - if (sector % 3 == 1 && sector < 3 * params.sections1) { - theta_base_mult = 1; - } else if (sector % 3 == 2 && sector < 3 * params.sections2) { - theta_base_mult = 2; - } else if (sector % 3 == 0 && sector < 3 * params.sections0) { - theta_base_mult = 0; - } else { - theta_base_mult = 3; - } - } else { - if (sector >= params.sections0 && sector < sec_w) { - theta_base_mult = 1; - if (is_vision) { - theta_scale_pwr = sector - params.sections0; - } - } else if (sector >= sec_w && sector < sec_e) { - theta_base_mult = 2; - if (is_vision) { - theta_scale_pwr = sector - sec_w; - } - } else if (sector >= sec_e) { - if (is_vision) { - theta_scale_pwr = sector - sec_e; - theta_scale_pwr = (i0 / 2) % sec_e; - } - theta_base_mult = 3; - } else if (is_vision) { - theta_scale_pwr = sector; - } - } - } - let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr)); - let thetas = rope_yarn(theta_base/freq_factor(i0), i0); - - let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision); - let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision); - - let x0 = f32(src0[i_src]); - let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]); - rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x); -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl deleted file mode 100644 index c74dc4cc923..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +++ /dev/null @@ -1,345 +0,0 @@ -#define(VARIANTS) -[ - { - "SHADER_NAME": "soft_max_f32", - "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_inplace", - "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_sink", - "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_sink_inplace", - "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_inplace", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_inplace", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_sink", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_sink", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] - } -] -#end(VARIANTS) - -#define(DECLS) - -#decl(BASE_BINDINGS) -@group(0) @binding(1) -var dst: array; - -@group(0) @binding(2) -var params: Params; -#enddecl(BASE_BINDINGS) - -#decl(BASE_BINDINGS_INPLACE) -@group(0) @binding(1) -var params: Params; -#enddecl(BASE_BINDINGS_INPLACE) - -#decl(SINK_BINDINGS) -@group(0) @binding(1) -var sinks: array; - -@group(0) @binding(2) -var dst: array; - -@group(0) @binding(3) -var params: Params; -#enddecl(SINK_BINDINGS) - -#decl(SINK_BINDINGS_INPLACE) -@group(0) @binding(1) -var sinks: array; - -@group(0) @binding(2) -var params: Params; -#enddecl(SINK_BINDINGS_INPLACE) - -#decl(MASK_BINDINGS) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var dst: array; - -@group(0) @binding(3) -var params: Params; -#enddecl(MASK_BINDINGS) - -#decl(MASK_BINDINGS_INPLACE) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var params: Params; -#enddecl(MASK_BINDINGS_INPLACE) - -#decl(MASK_SINK_BINDINGS) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var sinks: array; - -@group(0) @binding(3) -var dst: array; - -@group(0) @binding(4) -var params: Params; -#enddecl(MASK_SINK_BINDINGS) - -#decl(MASK_SINK_BINDINGS_INPLACE) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var sinks: array; - -@group(0) @binding(3) -var params: Params; -#enddecl(MASK_SINK_BINDINGS_INPLACE) - -#decl(NOT_INPLACE) -fn inter_value(i: u32) -> f32 { - return dst[i]; -} - -fn update(i: u32, val: f32) { - dst[i] = val; -} -#enddecl(NOT_INPLACE) - -#decl(INPLACE) -fn inter_value(i: u32) -> f32 { - return src[i]; -} - -fn update(i: u32, val: f32) { - src[i] = val; -} -#enddecl(INPLACE) - -#decl(NO_MASK) -fn mask_val(i: u32) -> f32 { - return 0.0; -} -#enddecl(NO_MASK) - -#decl(MASK) -fn mask_val(i: u32) -> f32 { - return f32(mask[i]); -} -#enddecl(MASK) - -#decl(NO_SINK) -fn lower_max_bound(i2: u32) -> f32 { - return -1e30; -} - -fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { - return val; -} -#enddecl(NO_SINK) - -#decl(SINK) -fn lower_max_bound(i2: u32) -> f32 { - return sinks[params.offset_sinks + i2]; -} - -fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { - return val + exp(sinks[params.offset_sinks + i2] - max_val); -} -#enddecl(SINK) - -#end(DECLS) - -#define(SHADER) -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_sinks: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_src11: u32, - stride_src12: u32, - stride_src13: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // shape of src0/dst - ne: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - // shape of src1 - ne12: u32, - ne13: u32, - - scale: f32, - max_bias: f32, - n_head_log2: f32, - m0: f32, - m1: f32, -}; - -@group(0) @binding(0) -var src: array; - -DECLS - -const CACHE_SIZE: u32 = 16; - -override wg_size: u32; -var scratch: array; - -@compute @workgroup_size(wg_size) -fn main(@builtin(workgroup_id) wid: vec3, - @builtin(local_invocation_id) lid: vec3) { - - var i = wid.x; - let i3 = i / (params.ne2 * params.ne1); - i = i % (params.ne2 * params.ne1); - let i2 = i / params.ne1; - let i1 = i % params.ne1; - let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; - let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11; - let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - let elems = (params.ne0 + wg_size - 1) / wg_size; - - let head = f32(i2); - let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0); - - var cache: array; - - var max_val = lower_max_bound(i2); - var col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col); - max_val = max(max_val, val); - if (col < CACHE_SIZE) { - cache[col] = val; - } - col += wg_size; - } - - scratch[lid.x] = max_val; - workgroupBarrier(); - var offset = wg_size / 2; - while (offset > 0) { - if (lid.x < offset) { - scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]); - } - offset = offset / 2; - workgroupBarrier(); - } - let row_max = scratch[0]; - workgroupBarrier(); - - var sum = 0.0f; - col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col), - cache[col], col < CACHE_SIZE); - let ex = exp(val - row_max); - sum += ex; - if (col < CACHE_SIZE) { - cache[col] = ex; - } else { - update(i_dst_row + col, ex); - } - col += wg_size; - } - - scratch[lid.x] = sum; - workgroupBarrier(); - offset = wg_size / 2; - while (offset > 0) { - if (lid.x < offset) { - scratch[lid.x] += scratch[lid.x + offset]; - } - offset = offset / 2; - workgroupBarrier(); - } - let row_sum = add_sinks(scratch[0], i2, row_max); - - let sum_recip = 1.0 / row_sum; - col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); - col += wg_size; - } -} -#end(SHADER) From bbdaa21aa7d301675f5cf7fd87f8c0b8c272dd29 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 08:51:39 +0300 Subject: [PATCH 240/249] ggml : remove obsolete rms_norm.wgsl (ggml/0) --- .../ggml-webgpu/wgsl-shaders/rms_norm.wgsl | 123 ------------------ 1 file changed, 123 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl deleted file mode 100644 index 712b921f1ab..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ /dev/null @@ -1,123 +0,0 @@ -#define(VARIANTS) - -[ - { - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_SUFFIX": "inplace", - "DECLS": ["INPLACE"] - }, -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) - -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - dst[dst_offset] = scale * src[src_offset]; -} - -@group(0) @binding(1) -var dst: array; - -@group(0) @binding(2) -var params: Params; - -#enddecl(NOT_INPLACE) - -#decl(INPLACE) - -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - src[dst_offset] = scale * src[src_offset]; -} - -@group(0) @binding(1) -var params: Params; - -#enddecl(INPLACE) - -#end(DECLS) - -#define(SHADER) - -struct Params { - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Shape of src/dst - ne0: u32, - ne1: u32, - ne2: u32, - ne3: u32, - - eps: f32 -}; - -@group(0) @binding(0) -var src: array; - -DECLS - -override wg_size: u32; -var scratch: array; - -@compute @workgroup_size(wg_size) -fn main(@builtin(workgroup_id) wid: vec3, - @builtin(local_invocation_id) lid: vec3) { - - // one thread per row - var i = wid.x; - let i3 = i / (params.ne2 * params.ne1); - i = i % (params.ne2 * params.ne1); - let i2 = i / params.ne1; - let i1 = i % params.ne1; - let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; - let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - - let elems = (params.ne0 + wg_size - 1) / wg_size; - - var sum = 0.0f; - var col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - sum += pow(src[i_src_row + col], 2.0); - col += wg_size; - } - - scratch[lid.x] = sum; - workgroupBarrier(); - var offset = wg_size / 2; - while (offset > 0) { - if (lid.x < offset) { - scratch[lid.x] += scratch[lid.x + offset]; - } - offset = offset / 2; - workgroupBarrier(); - } - sum = scratch[0]; - - let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); - col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - update(i_src_row + col, i_dst_row + col, scale); - col += wg_size; - } -} -#end(SHADER) From 8384aa8086714d6177f24eb5c409b39949efd2ce Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 08:53:58 +0300 Subject: [PATCH 241/249] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index a03455e74c8..812e721a8c5 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -b70770970e84c30a007b3859a453768b3ece2d3d +19eac6f0edaf285506eb6228d31bb9caeda9aba1 From 18162bcf6120551cfd447c81a09a98c6ed3db675 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 08:54:20 +0300 Subject: [PATCH 242/249] cmake : add FindNCCL.cmake (ggml/0) --- ggml/cmake/FindNCCL.cmake | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 ggml/cmake/FindNCCL.cmake diff --git a/ggml/cmake/FindNCCL.cmake b/ggml/cmake/FindNCCL.cmake new file mode 100644 index 00000000000..67511e2d56a --- /dev/null +++ b/ggml/cmake/FindNCCL.cmake @@ -0,0 +1,36 @@ +# cmake/FindNCCL.cmake + +# NVIDIA does not distribute CMake files with NCCl, therefore use this file to find it instead. + +find_path(NCCL_INCLUDE_DIR + NAMES nccl.h + HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda + PATH_SUFFIXES include +) + +find_library(NCCL_LIBRARY + NAMES nccl + HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda + PATH_SUFFIXES lib lib64 +) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL + DEFAULT_MSG + NCCL_LIBRARY NCCL_INCLUDE_DIR +) + +if(NCCL_FOUND) + set(NCCL_LIBRARIES ${NCCL_LIBRARY}) + set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR}) + + if(NOT TARGET NCCL::NCCL) + add_library(NCCL::NCCL UNKNOWN IMPORTED) + set_target_properties(NCCL::NCCL PROPERTIES + IMPORTED_LOCATION "${NCCL_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}" + ) + endif() +endif() + +mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY) From 4bf733672b2871d4153158af4f621a6dd9104f4a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 09:01:24 +0300 Subject: [PATCH 243/249] talk-llama : sync llama.cpp --- examples/talk-llama/llama-adapter.cpp | 17 +- examples/talk-llama/llama-adapter.h | 4 +- examples/talk-llama/llama-arch.cpp | 2099 +---------------- examples/talk-llama/llama-arch.h | 20 +- examples/talk-llama/llama-batch.h | 2 +- examples/talk-llama/llama-chat.cpp | 51 +- examples/talk-llama/llama-chat.h | 5 +- examples/talk-llama/llama-context.cpp | 224 +- examples/talk-llama/llama-context.h | 14 +- examples/talk-llama/llama-ext.h | 84 +- examples/talk-llama/llama-grammar.cpp | 74 +- examples/talk-llama/llama-graph.cpp | 259 +- examples/talk-llama/llama-graph.h | 33 + examples/talk-llama/llama-hparams.h | 4 + examples/talk-llama/llama-impl.cpp | 2 +- examples/talk-llama/llama-kv-cache.cpp | 223 +- examples/talk-llama/llama-kv-cache.h | 33 +- .../talk-llama/llama-memory-hybrid-iswa.cpp | 6 +- examples/talk-llama/llama-memory-hybrid.cpp | 6 +- .../talk-llama/llama-memory-recurrent.cpp | 12 +- examples/talk-llama/llama-mmap.cpp | 35 +- examples/talk-llama/llama-mmap.h | 1 + examples/talk-llama/llama-model-loader.cpp | 46 +- examples/talk-llama/llama-model-loader.h | 1 + examples/talk-llama/llama-model-saver.cpp | 127 +- examples/talk-llama/llama-model-saver.h | 4 + examples/talk-llama/llama-model.cpp | 1489 +++++++----- examples/talk-llama/llama-model.h | 63 +- examples/talk-llama/llama-quant.cpp | 220 +- examples/talk-llama/llama-vocab.cpp | 148 +- examples/talk-llama/llama-vocab.h | 1 + examples/talk-llama/llama.cpp | 906 ++----- examples/talk-llama/llama.h | 82 +- examples/talk-llama/models/afmoe.cpp | 19 +- examples/talk-llama/models/apertus.cpp | 18 +- examples/talk-llama/models/arcee.cpp | 29 +- examples/talk-llama/models/arctic.cpp | 16 +- examples/talk-llama/models/baichuan.cpp | 17 +- examples/talk-llama/models/bailingmoe.cpp | 28 +- examples/talk-llama/models/bailingmoe2.cpp | 14 +- examples/talk-llama/models/bert.cpp | 38 +- examples/talk-llama/models/bitnet.cpp | 38 +- examples/talk-llama/models/bloom.cpp | 22 +- examples/talk-llama/models/chameleon.cpp | 28 +- examples/talk-llama/models/chatglm.cpp | 45 +- examples/talk-llama/models/codeshell.cpp | 14 +- examples/talk-llama/models/cogvlm.cpp | 10 +- examples/talk-llama/models/cohere2-iswa.cpp | 28 +- examples/talk-llama/models/command-r.cpp | 25 +- examples/talk-llama/models/dbrx.cpp | 18 +- examples/talk-llama/models/deci.cpp | 27 +- examples/talk-llama/models/deepseek.cpp | 25 +- examples/talk-llama/models/deepseek2.cpp | 40 +- examples/talk-llama/models/dots1.cpp | 16 +- examples/talk-llama/models/dream.cpp | 22 +- examples/talk-llama/models/ernie4-5-moe.cpp | 25 +- examples/talk-llama/models/ernie4-5.cpp | 25 +- examples/talk-llama/models/eurobert.cpp | 16 +- examples/talk-llama/models/exaone-moe.cpp | 16 +- examples/talk-llama/models/exaone.cpp | 27 +- examples/talk-llama/models/exaone4.cpp | 17 +- examples/talk-llama/models/falcon-h1.cpp | 17 +- examples/talk-llama/models/falcon.cpp | 12 +- .../talk-llama/models/gemma-embedding.cpp | 18 +- examples/talk-llama/models/gemma.cpp | 17 +- examples/talk-llama/models/gemma2-iswa.cpp | 16 +- examples/talk-llama/models/gemma3.cpp | 18 +- examples/talk-llama/models/gemma3n-iswa.cpp | 98 +- examples/talk-llama/models/gemma4-iswa.cpp | 322 +++ examples/talk-llama/models/glm4-moe.cpp | 25 +- examples/talk-llama/models/glm4.cpp | 41 +- examples/talk-llama/models/gpt2.cpp | 18 +- examples/talk-llama/models/gptneox.cpp | 15 +- examples/talk-llama/models/granite-hybrid.cpp | 28 +- examples/talk-llama/models/granite.cpp | 29 +- examples/talk-llama/models/grok.cpp | 25 +- examples/talk-llama/models/grovemoe.cpp | 16 +- examples/talk-llama/models/hunyuan-dense.cpp | 66 +- examples/talk-llama/models/hunyuan-moe.cpp | 25 +- examples/talk-llama/models/internlm2.cpp | 25 +- examples/talk-llama/models/jais.cpp | 28 +- examples/talk-llama/models/jais2.cpp | 23 +- examples/talk-llama/models/jamba.cpp | 19 +- examples/talk-llama/models/kimi-linear.cpp | 5 +- examples/talk-llama/models/lfm2.cpp | 17 +- examples/talk-llama/models/llada-moe.cpp | 16 +- examples/talk-llama/models/llada.cpp | 15 +- examples/talk-llama/models/llama.cpp | 28 +- .../models/{llama-iswa.cpp => llama4.cpp} | 41 +- examples/talk-llama/models/maincoder.cpp | 16 +- examples/talk-llama/models/mamba-base.cpp | 8 +- examples/talk-llama/models/mimo2-iswa.cpp | 2 +- examples/talk-llama/models/minicpm3.cpp | 2 +- examples/talk-llama/models/minimax-m2.cpp | 2 +- examples/talk-llama/models/mistral3.cpp | 25 +- examples/talk-llama/models/models.h | 36 +- examples/talk-llama/models/modern-bert.cpp | 17 +- examples/talk-llama/models/mpt.cpp | 26 +- examples/talk-llama/models/nemotron-h.cpp | 47 +- examples/talk-llama/models/nemotron.cpp | 25 +- examples/talk-llama/models/neo-bert.cpp | 16 +- examples/talk-llama/models/olmo.cpp | 25 +- examples/talk-llama/models/olmo2.cpp | 2 +- examples/talk-llama/models/olmoe.cpp | 2 +- .../talk-llama/models/openai-moe-iswa.cpp | 25 +- examples/talk-llama/models/openelm.cpp | 2 +- examples/talk-llama/models/orion.cpp | 28 +- examples/talk-llama/models/paddleocr.cpp | 25 +- examples/talk-llama/models/pangu-embedded.cpp | 20 +- examples/talk-llama/models/phi2.cpp | 29 +- examples/talk-llama/models/phi3.cpp | 26 +- examples/talk-llama/models/plamo.cpp | 16 +- examples/talk-llama/models/plamo2.cpp | 3 +- examples/talk-llama/models/plamo3.cpp | 4 +- examples/talk-llama/models/plm.cpp | 2 +- examples/talk-llama/models/qwen.cpp | 14 +- examples/talk-llama/models/qwen2.cpp | 28 +- examples/talk-llama/models/qwen2moe.cpp | 25 +- examples/talk-llama/models/qwen2vl.cpp | 19 +- examples/talk-llama/models/qwen3.cpp | 19 +- examples/talk-llama/models/qwen35.cpp | 12 +- examples/talk-llama/models/qwen35moe.cpp | 12 +- examples/talk-llama/models/qwen3moe.cpp | 19 +- examples/talk-llama/models/qwen3next.cpp | 25 +- examples/talk-llama/models/qwen3vl-moe.cpp | 16 +- examples/talk-llama/models/qwen3vl.cpp | 16 +- examples/talk-llama/models/refact.cpp | 16 +- examples/talk-llama/models/rnd1.cpp | 16 +- examples/talk-llama/models/rwkv6.cpp | 2 +- examples/talk-llama/models/rwkv7.cpp | 2 +- examples/talk-llama/models/seed-oss.cpp | 25 +- examples/talk-llama/models/smallthinker.cpp | 17 +- examples/talk-llama/models/smollm3.cpp | 25 +- examples/talk-llama/models/stablelm.cpp | 28 +- examples/talk-llama/models/starcoder.cpp | 18 +- examples/talk-llama/models/starcoder2.cpp | 25 +- examples/talk-llama/models/step35-iswa.cpp | 6 +- examples/talk-llama/models/t5-enc.cpp | 96 - .../talk-llama/models/{t5-dec.cpp => t5.cpp} | 116 +- examples/talk-llama/models/t5encoder.cpp | 3 + .../talk-llama/models/wavtokenizer-dec.cpp | 2 +- examples/talk-llama/models/xverse.cpp | 16 +- examples/talk-llama/unicode.cpp | 178 +- examples/talk-llama/unicode.h | 2 +- 144 files changed, 3675 insertions(+), 5535 deletions(-) create mode 100644 examples/talk-llama/models/gemma4-iswa.cpp rename examples/talk-llama/models/{llama-iswa.cpp => llama4.cpp} (81%) delete mode 100644 examples/talk-llama/models/t5-enc.cpp rename examples/talk-llama/models/{t5-dec.cpp => t5.cpp} (64%) create mode 100644 examples/talk-llama/models/t5encoder.cpp diff --git a/examples/talk-llama/llama-adapter.cpp b/examples/talk-llama/llama-adapter.cpp index d6a5800e63a..4a1aaa955a8 100644 --- a/examples/talk-llama/llama-adapter.cpp +++ b/examples/talk-llama/llama-adapter.cpp @@ -294,7 +294,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } // get extra buffer types of the CPU - // TODO: a more general solution for non-CPU extra buft should be imlpemented in the future + // TODO: a more general solution for non-CPU extra buft should be implemented in the future // ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948 std::vector buft_extra; { @@ -418,7 +418,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) { - llama_adapter_lora * adapter = new llama_adapter_lora(); + llama_adapter_lora * adapter = new llama_adapter_lora(model); try { llama_adapter_lora_init_impl(*model, path_lora, *adapter); @@ -471,8 +471,17 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, return snprintf(buf, buf_size, "%s", it->second.c_str()); } -void llama_adapter_lora_free(llama_adapter_lora *) { - // deprecated: adapters are freed by llama_model's destructor +void llama_adapter_lora_free(llama_adapter_lora * adapter) { + if (adapter == nullptr) { + return; + } + + if (adapter->model != nullptr) { + adapter->model->loras.erase(adapter); + adapter->model = nullptr; + } + + delete adapter; } uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) { diff --git a/examples/talk-llama/llama-adapter.h b/examples/talk-llama/llama-adapter.h index aa3ab63ad75..f0b1e50f816 100644 --- a/examples/talk-llama/llama-adapter.h +++ b/examples/talk-llama/llama-adapter.h @@ -61,6 +61,8 @@ struct llama_adapter_lora_weight { }; struct llama_adapter_lora { + llama_model * model = nullptr; + // map tensor name to lora_a_b std::unordered_map ab_map; @@ -75,7 +77,7 @@ struct llama_adapter_lora { // activated lora (aLoRA) std::vector alora_invocation_tokens; - llama_adapter_lora() = default; + explicit llama_adapter_lora(llama_model * model) : model(model) {} ~llama_adapter_lora() = default; llama_adapter_lora_weight * get_weight(ggml_tensor * w); diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 799d16167ba..633a66fc665 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -56,6 +56,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, + { LLM_ARCH_GEMMA4, "gemma4" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, @@ -73,6 +74,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, @@ -107,6 +109,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" }, + { LLM_ARCH_HUNYUAN_VL, "hunyuan_vl" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, @@ -123,6 +126,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_MISTRAL4, "mistral4" }, { LLM_ARCH_PADDLEOCR, "paddleocr" }, { LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_STEP35, "step35" }, @@ -163,6 +167,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, { LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" }, + { LLM_KV_EMBEDDING_LENGTH_PER_LAYER, "%s.embedding_length_per_layer_input" }, { LLM_KV_FEATURES_LENGTH, "%s.features_length" }, { LLM_KV_BLOCK_COUNT, "%s.block_count" }, { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, @@ -236,6 +241,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" }, { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, + { LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" }, @@ -245,6 +251,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ALPHA, "%s.rope.scaling.alpha" }, { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, @@ -362,6 +369,9 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" }, { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_FFN_POST_NORM_1, "blk.%d.post_ffw_norm_1" }, + { LLM_TENSOR_FFN_POST_NORM_2, "blk.%d.post_ffw_norm_2" }, + { LLM_TENSOR_FFN_PRE_NORM_2, "blk.%d.pre_ffw_norm_2" }, { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, @@ -371,6 +381,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_LAYER_OUT_SCALE, "blk.%d.layer_output_scale" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, { LLM_TENSOR_POS_EMBD, "position_embd" }, { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, @@ -538,2016 +549,6 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, }; -static std::set llm_get_tensor_names(llm_arch arch) { - switch (arch) { - case LLM_ARCH_CLIP: - return {}; - case LLM_ARCH_LLAMA: - case LLM_ARCH_DECI: - case LLM_ARCH_MISTRAL3: - case LLM_ARCH_LLAMA_EMBED: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_ARCEE: - case LLM_ARCH_STARCODER2: - case LLM_ARCH_NEMOTRON: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_AFMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_LLAMA4: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_BAICHUAN: - case LLM_ARCH_ORION: - case LLM_ARCH_XVERSE: - case LLM_ARCH_EXAONE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_FALCON: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GROK: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_ATTN_OUT_NORM, - }; - case LLM_ARCH_GPT2: - case LLM_ARCH_STARCODER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_GPTNEOX: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_MPT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_ACT, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - }; - case LLM_ARCH_REFACT: - case LLM_ARCH_QWEN2: - case LLM_ARCH_QWEN2VL: - case LLM_ARCH_INTERNLM2: - case LLM_ARCH_GRANITE: - case LLM_ARCH_ERNIE4_5: - case LLM_ARCH_PADDLEOCR: - case LLM_ARCH_SMOLLM3: - case LLM_ARCH_DREAM: - case LLM_ARCH_LLADA: - case LLM_ARCH_PANGU_EMBED: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_CLS, - LLM_TENSOR_CLS_OUT, - }; - case LLM_ARCH_NOMIC_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_NOMIC_BERT_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_NEO_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_ENC_OUTPUT_NORM, - LLM_TENSOR_CLS, - LLM_TENSOR_CLS_OUT, - }; - case LLM_ARCH_EUROBERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_MODERN_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_CLS, - LLM_TENSOR_CLS_OUT, - LLM_TENSOR_CLS_NORM, - }; - case LLM_ARCH_JINA_BERT_V2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_CLS, - }; - case LLM_ARCH_JINA_BERT_V3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_LAYER_OUT_NORM, - }; - case LLM_ARCH_BLOOM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_STABLELM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - }; - case LLM_ARCH_QWEN: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_QWEN2MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_QWEN3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_CLS_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_QWEN3MOE: - case LLM_ARCH_QWEN3VLMOE: - case LLM_ARCH_OLMOE: - case LLM_ARCH_LLADA_MOE: - case LLM_ARCH_RND1: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_QWEN3NEXT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_SSM_A_NOSCAN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_BETA_ALPHA, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_QWEN35: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SSM_A_NOSCAN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_BETA, - LLM_TENSOR_SSM_ALPHA, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_QWEN35MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_SSM_A_NOSCAN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_BETA, - LLM_TENSOR_SSM_ALPHA, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_QWEN3VL: - case LLM_ARCH_CHAMELEON: - case LLM_ARCH_HUNYUAN_DENSE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_CLS_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PHI2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PHI3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PHIMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_PLAMO: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PLAMO2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_X, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_SSM_DT_NORM, - LLM_TENSOR_SSM_B_NORM, - LLM_TENSOR_SSM_C_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_PLAMO3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_CODESHELL: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_MINICPM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - }; - case LLM_ARCH_MINICPM3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_GEMMA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GEMMA2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_GEMMA3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_GEMMA3N: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_PER_LAYER_TOKEN_EMBD, - LLM_TENSOR_PER_LAYER_MODEL_PROJ, - LLM_TENSOR_PER_LAYER_PROJ_NORM, - LLM_TENSOR_ALTUP_UNEMBD_PROJ, - LLM_TENSOR_ALTUP_PROJ, - LLM_TENSOR_PER_LAYER_INP_GATE, - LLM_TENSOR_PER_LAYER_PROJ, - LLM_TENSOR_PER_LAYER_POST_NORM, - LLM_TENSOR_ALTUP_CORRECT_COEF, - LLM_TENSOR_ALTUP_CORRECT_SCALE, - LLM_TENSOR_ALTUP_PREDICT_COEF, - LLM_TENSOR_ALTUP_ROUTER, - LLM_TENSOR_ALTUP_ROUTER_NORM, - LLM_TENSOR_LAUREL_L, - LLM_TENSOR_LAUREL_R, - LLM_TENSOR_LAUREL_POST_NORM, - }; - case LLM_ARCH_GEMMA_EMBEDDING: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_DENSE_2_OUT, - LLM_TENSOR_DENSE_3_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_MAMBA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_X, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_MAMBA2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_JAMBA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_X, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_DT_NORM, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_B_NORM, - LLM_TENSOR_SSM_C_NORM, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_FALCON_H1: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_COMMAND_R: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - }; - case LLM_ARCH_COHERE2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_DBRX: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_OLMO: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_OLMO2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_OPENELM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_ARCTIC: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM_EXPS, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_DEEPSEEK: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_DEEPSEEK2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_K_B, - LLM_TENSOR_ATTN_V_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_PLM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_CHATGLM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_GLM4: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_GLM4_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_GLM_DSA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_K_B, - LLM_TENSOR_ATTN_V_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_INDEXER_K_NORM, - LLM_TENSOR_INDEXER_PROJ, - LLM_TENSOR_INDEXER_ATTN_K, - LLM_TENSOR_INDEXER_ATTN_Q_B, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_BITNET: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_SUB_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_SUB_NORM, - }; - case LLM_ARCH_T5: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_DEC_OUTPUT_NORM, - LLM_TENSOR_DEC_ATTN_NORM, - LLM_TENSOR_DEC_ATTN_Q, - LLM_TENSOR_DEC_ATTN_K, - LLM_TENSOR_DEC_ATTN_V, - LLM_TENSOR_DEC_ATTN_OUT, - LLM_TENSOR_DEC_ATTN_REL_B, - LLM_TENSOR_DEC_CROSS_ATTN_NORM, - LLM_TENSOR_DEC_CROSS_ATTN_Q, - LLM_TENSOR_DEC_CROSS_ATTN_K, - LLM_TENSOR_DEC_CROSS_ATTN_V, - LLM_TENSOR_DEC_CROSS_ATTN_OUT, - LLM_TENSOR_DEC_CROSS_ATTN_REL_B, - LLM_TENSOR_DEC_FFN_NORM, - LLM_TENSOR_DEC_FFN_GATE, - LLM_TENSOR_DEC_FFN_DOWN, - LLM_TENSOR_DEC_FFN_UP, - LLM_TENSOR_ENC_OUTPUT_NORM, - LLM_TENSOR_ENC_ATTN_NORM, - LLM_TENSOR_ENC_ATTN_Q, - LLM_TENSOR_ENC_ATTN_K, - LLM_TENSOR_ENC_ATTN_V, - LLM_TENSOR_ENC_ATTN_OUT, - LLM_TENSOR_ENC_ATTN_REL_B, - LLM_TENSOR_ENC_FFN_NORM, - LLM_TENSOR_ENC_FFN_GATE, - LLM_TENSOR_ENC_FFN_DOWN, - LLM_TENSOR_ENC_FFN_UP, - }; - case LLM_ARCH_T5ENCODER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ENC_OUTPUT_NORM, - LLM_TENSOR_ENC_ATTN_NORM, - LLM_TENSOR_ENC_ATTN_Q, - LLM_TENSOR_ENC_ATTN_K, - LLM_TENSOR_ENC_ATTN_V, - LLM_TENSOR_ENC_ATTN_OUT, - LLM_TENSOR_ENC_ATTN_REL_B, - LLM_TENSOR_ENC_FFN_NORM, - LLM_TENSOR_ENC_FFN_GATE, - LLM_TENSOR_ENC_FFN_DOWN, - LLM_TENSOR_ENC_FFN_UP, - }; - case LLM_ARCH_JAIS: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_JAIS2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_NEMOTRON_H: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_NEMOTRON_H_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - // mamba(2) ssm layers - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - // attention layers - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - // dense FFN - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - // MoE FFN (for MoE layers) - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_FFN_LATENT_DOWN, - LLM_TENSOR_FFN_LATENT_UP, - // MoE shared expert layer - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_EXAONE4: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_EXAONE_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_RWKV6: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_LERP_X, - LLM_TENSOR_TIME_MIX_LERP_W, - LLM_TENSOR_TIME_MIX_LERP_K, - LLM_TENSOR_TIME_MIX_LERP_V, - LLM_TENSOR_TIME_MIX_LERP_R, - LLM_TENSOR_TIME_MIX_LERP_G, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_FIRST, - LLM_TENSOR_TIME_MIX_DECAY, - LLM_TENSOR_TIME_MIX_DECAY_W1, - LLM_TENSOR_TIME_MIX_DECAY_W2, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_GATE, - LLM_TENSOR_TIME_MIX_LN, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_CHANNEL_MIX_LERP_K, - LLM_TENSOR_CHANNEL_MIX_LERP_R, - LLM_TENSOR_CHANNEL_MIX_KEY, - LLM_TENSOR_CHANNEL_MIX_VALUE, - LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, - }; - case LLM_ARCH_RWKV6QWEN2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_LERP_X, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_FIRST, - LLM_TENSOR_TIME_MIX_DECAY, - LLM_TENSOR_TIME_MIX_DECAY_W1, - LLM_TENSOR_TIME_MIX_DECAY_W2, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_GATE, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_RWKV7: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_TIME_MIX_W0, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_A0, - LLM_TENSOR_TIME_MIX_A1, - LLM_TENSOR_TIME_MIX_A2, - LLM_TENSOR_TIME_MIX_V0, - LLM_TENSOR_TIME_MIX_V1, - LLM_TENSOR_TIME_MIX_V2, - LLM_TENSOR_TIME_MIX_G1, - LLM_TENSOR_TIME_MIX_G2, - LLM_TENSOR_TIME_MIX_K_K, - LLM_TENSOR_TIME_MIX_K_A, - LLM_TENSOR_TIME_MIX_R_K, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_LN, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_CHANNEL_MIX_LERP_K, - LLM_TENSOR_CHANNEL_MIX_KEY, - LLM_TENSOR_CHANNEL_MIX_VALUE, - }; - case LLM_ARCH_ARWKV7: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_TIME_MIX_W0, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_A0, - LLM_TENSOR_TIME_MIX_A1, - LLM_TENSOR_TIME_MIX_A2, - LLM_TENSOR_TIME_MIX_V0, - LLM_TENSOR_TIME_MIX_V1, - LLM_TENSOR_TIME_MIX_V2, - LLM_TENSOR_TIME_MIX_G1, - LLM_TENSOR_TIME_MIX_G2, - LLM_TENSOR_TIME_MIX_K_K, - LLM_TENSOR_TIME_MIX_K_A, - LLM_TENSOR_TIME_MIX_R_K, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_LN, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GRANITE_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_GRANITE_HYBRID: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_WAVTOKENIZER_DEC: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_CONV1D, - LLM_TENSOR_CONVNEXT_DW, - LLM_TENSOR_CONVNEXT_NORM, - LLM_TENSOR_CONVNEXT_PW1, - LLM_TENSOR_CONVNEXT_PW2, - LLM_TENSOR_CONVNEXT_GAMMA, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_POS_NET_CONV1, - LLM_TENSOR_POS_NET_CONV2, - LLM_TENSOR_POS_NET_NORM, - LLM_TENSOR_POS_NET_NORM1, - LLM_TENSOR_POS_NET_NORM2, - LLM_TENSOR_POS_NET_ATTN_NORM, - LLM_TENSOR_POS_NET_ATTN_Q, - LLM_TENSOR_POS_NET_ATTN_K, - LLM_TENSOR_POS_NET_ATTN_V, - LLM_TENSOR_POS_NET_ATTN_OUT, - }; - case LLM_ARCH_BAILINGMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_BAILINGMOE2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - LLM_TENSOR_LAYER_OUT_NORM, - }; - case LLM_ARCH_DOTS1: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_ERNIE4_5_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_HUNYUAN_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_OPENAI_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_SINKS, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_LFM2: - return { - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SHORTCONV_CONV, - LLM_TENSOR_SHORTCONV_INPROJ, - LLM_TENSOR_SHORTCONV_OUTPROJ, - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM_LFM2, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_DENSE_2_OUT, - }; - case LLM_ARCH_LFM2MOE: - return { - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SHORTCONV_CONV, - LLM_TENSOR_SHORTCONV_INPROJ, - LLM_TENSOR_SHORTCONV_OUTPROJ, - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM_LFM2, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_SMALLTHINKER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_APERTUS: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_SEED_OSS: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GROVEMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_CHEXPS, - LLM_TENSOR_FFN_DOWN_CHEXPS, - LLM_TENSOR_FFN_UP_CHEXPS, - }; - case LLM_ARCH_MINIMAX_M2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_COGVLM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_VISEXP_ATTN_QKV, - LLM_TENSOR_VISEXP_ATTN_OUT, - LLM_TENSOR_VISEXP_FFN_GATE, - LLM_TENSOR_VISEXP_FFN_DOWN, - LLM_TENSOR_VISEXP_FFN_UP, - }; - case LLM_ARCH_MIMO2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_SINKS, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_STEP35: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_GPTJ: - case LLM_ARCH_UNKNOWN: - return { - LLM_TENSOR_TOKEN_EMBD, - }; - case LLM_ARCH_MAINCODER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_KIMI_LINEAR: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - // Dense FFN (layer 0 only) - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - // MoE FFN (layers 1+) - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - // Shared experts - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - // KDA (using SSM_ enum prefix, keeping GGUF names for backward compat) - LLM_TENSOR_SSM_CONV1D_Q, - LLM_TENSOR_SSM_CONV1D_K, - LLM_TENSOR_SSM_CONV1D_V, - LLM_TENSOR_SSM_F_A, - LLM_TENSOR_SSM_F_B, - LLM_TENSOR_SSM_BETA, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_G_A, - LLM_TENSOR_SSM_G_B, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_NORM, - // MLA - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_K_B, - LLM_TENSOR_ATTN_V_B, - LLM_TENSOR_ATTN_KV_A_NORM, - }; - default: - GGML_ABORT("unknown architecture for tensor mapping"); - } -} - // declare information about the model weight tensors: // - the layer in which the tensor is going to be used. this is needed in order to assign the correct buffer type for the weight // - the operator which is going to use the weight. this is needed to determine if the respective backend supports the operator @@ -2559,20 +560,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { // example: https://github.com/ggml-org/llama.cpp/pull/17548 // static const std::map LLM_TENSOR_INFOS = { - {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}}, - {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output - {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output - {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer) + {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, {LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, {LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, @@ -2680,11 +681,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_PRE_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_POST_NORM_1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_POST_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_LAYER_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_LAYER_OUT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_Q_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_KV_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2705,9 +710,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, // altup / laurel (gemma 3n) - {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2723,7 +728,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, - {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, + {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2778,18 +783,13 @@ std::string LLM_KV::operator()(llm_kv kv) const { } LLM_TN_IMPL::LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid) - : arch(arch), tensor(tensor), suffix(suffix), bid(bid), xid(xid), - model_tensors(llm_get_tensor_names(arch)) {} + : arch(arch), tensor(tensor), suffix(suffix), bid(bid), xid(xid) {} std::string LLM_TN_IMPL::str() const { if (LLM_TENSOR_NAMES.find(tensor) == LLM_TENSOR_NAMES.end()) { GGML_ABORT("unknown tensor name for tensor id %d", static_cast(tensor)); } - if (model_tensors.find(tensor) == model_tensors.end()) { - return LLM_TENSOR_NAMES.at(tensor); - } - std::string name = ::format(LLM_TENSOR_NAMES.at(tensor), bid, xid); if (suffix != nullptr) { name += "."; @@ -2875,3 +875,34 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { return false; } } + +bool llm_arch_supports_sm_tensor(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_GROK: + case LLM_ARCH_MPT: + case LLM_ARCH_PLAMO2: + case LLM_ARCH_MINICPM3: + case LLM_ARCH_GEMMA3N: + case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: + case LLM_ARCH_JAMBA: + case LLM_ARCH_FALCON_H1: + case LLM_ARCH_OLMO2: + case LLM_ARCH_OLMOE: + case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_GLM_DSA: + case LLM_ARCH_BITNET: + case LLM_ARCH_T5: + case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: + case LLM_ARCH_GRANITE_HYBRID: + case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: + case LLM_ARCH_MINIMAX_M2: + case LLM_ARCH_MISTRAL4: + case LLM_ARCH_KIMI_LINEAR: + return false; + default: + return true; + } +} diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index b1b1dcf1883..8f335f5c7b3 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -60,6 +60,7 @@ enum llm_arch { LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, + LLM_ARCH_GEMMA4, LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, @@ -77,6 +78,7 @@ enum llm_arch { LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_DEEPSEEK2OCR, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, @@ -111,6 +113,7 @@ enum llm_arch { LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_HUNYUAN_DENSE, + LLM_ARCH_HUNYUAN_VL, LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, @@ -127,6 +130,7 @@ enum llm_arch { LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, + LLM_ARCH_MISTRAL4, LLM_ARCH_PADDLEOCR, LLM_ARCH_MIMO2, LLM_ARCH_STEP35, @@ -167,6 +171,7 @@ enum llm_kv { LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, LLM_KV_EMBEDDING_LENGTH_OUT, + LLM_KV_EMBEDDING_LENGTH_PER_LAYER, LLM_KV_FEATURES_LENGTH, LLM_KV_BLOCK_COUNT, LLM_KV_LEADING_DENSE_BLOCK_COUNT, @@ -240,6 +245,7 @@ enum llm_kv { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, LLM_KV_ATTENTION_INDEXER_TOP_K, + LLM_KV_ATTENTION_SHARED_KV_LAYERS, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_COUNT_SWA, @@ -249,6 +255,7 @@ enum llm_kv { LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ALPHA, LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, @@ -367,6 +374,9 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_POST_NORM_1, + LLM_TENSOR_FFN_POST_NORM_2, + LLM_TENSOR_FFN_PRE_NORM_2, LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, @@ -391,6 +401,7 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_LAYER_OUT_SCALE, LLM_TENSOR_POST_ATTN_NORM, LLM_TENSOR_POST_MLP_NORM, LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n @@ -576,8 +587,6 @@ struct LLM_TN_IMPL { const int bid; const int xid; - const std::set model_tensors; - LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid); std::string str() const; @@ -623,6 +632,7 @@ llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); -bool llm_arch_is_recurrent(const llm_arch & arch); -bool llm_arch_is_hybrid (const llm_arch & arch); -bool llm_arch_is_diffusion(const llm_arch & arch); +bool llm_arch_is_recurrent (const llm_arch & arch); +bool llm_arch_is_hybrid (const llm_arch & arch); +bool llm_arch_is_diffusion (const llm_arch & arch); +bool llm_arch_supports_sm_tensor(const llm_arch & arch); diff --git a/examples/talk-llama/llama-batch.h b/examples/talk-llama/llama-batch.h index 8e6fac0efab..f77520e86c3 100644 --- a/examples/talk-llama/llama-batch.h +++ b/examples/talk-llama/llama-batch.h @@ -18,7 +18,7 @@ struct llama_ubatch { } // typical for M-RoPE cases: - // 0 - sequantial position of the tokens/embeddings in the sequence + // 0 - sequential position of the tokens/embeddings in the sequence // 1 - y position in the image // 2 - x position in the image // 3 - other diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index c415a998f33..6554a89b28a 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -49,6 +49,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK }, { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 }, { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 }, + { "deepseek-ocr", LLM_CHAT_TEMPLATE_DEEPSEEK_OCR }, { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, { "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 }, @@ -59,7 +60,8 @@ static const std::map LLM_CHAT_TEMPLATES = { { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 }, { "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, - { "granite", LLM_CHAT_TEMPLATE_GRANITE }, + { "granite", LLM_CHAT_TEMPLATE_GRANITE_3_X }, + { "granite-4.0", LLM_CHAT_TEMPLATE_GRANITE_4_0 }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, @@ -71,6 +73,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, + { "hunyuan-ocr", LLM_CHAT_TEMPLATE_HUNYUAN_OCR }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, @@ -190,7 +193,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) { return LLM_CHAT_TEMPLATE_RWKV_WORLD; } else if (tmpl_contains("<|start_of_role|>")) { - return LLM_CHAT_TEMPLATE_GRANITE; + if (tmpl_contains("") || tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_GRANITE_4_0; + } + return LLM_CHAT_TEMPLATE_GRANITE_3_X; } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) { return LLM_CHAT_TEMPLATE_GIGACHAT; } else if (tmpl_contains("<|role_start|>")) { @@ -211,6 +217,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) { return LLM_CHAT_TEMPLATE_OPENAI_MOE; + } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_begin▁of▁sentence|>")) { + return LLM_CHAT_TEMPLATE_HUNYUAN_OCR; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { @@ -548,6 +556,11 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << LU8("<|Assistant|>"); } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_OCR) { + for (auto message : chat) { + // no template + ss << message->content; + } } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb // EXAONE-3.0-7.8B-Instruct @@ -611,8 +624,8 @@ int32_t llm_chat_apply_template( ss << "Assistant: " << trim(chat[i]->content) << "\n\n"; } } - } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) { - // IBM Granite template + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_3_X) { + // IBM Granite 3.x template for (const auto & message : chat) { std::string role(message->role); ss << "<|start_of_role|>" << role << "<|end_of_role|>"; @@ -624,6 +637,20 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|start_of_role|>assistant<|end_of_role|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_4_0) { + // IBM Granite 4.0 template + for (const auto & message : chat) { + std::string role(message->role); + if (role == "assistant_tool_call") { + ss << "<|start_of_role|>assistant<|end_of_role|><|tool_call|>"; + } else { + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + } + ss << message->content << "<|end_of_text|>\n"; + } + if (add_ass) { + ss << "<|start_of_role|>assistant<|end_of_role|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; @@ -798,6 +825,22 @@ int32_t llm_chat_apply_template( ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>"; } } + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_OCR) { + // tencent/HunyuanOCR + ss << "<|hy_begin▁of▁sentence|>"; + for (size_t i = 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (i == 0 && role == "system") { + ss << chat[i]->content << "<|hy_place▁holder▁no▁3|>"; + continue; + } + + if (role == "user") { + ss << chat[i]->content << "<|hy_User|>"; + } else if (role == "assistant") { + ss << chat[i]->content << "<|hy_Assistant|>"; + } + } } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) { // moonshotai/Kimi-K2-Instruct for (auto message : chat) { diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index 9ed1db128ec..13f936a946c 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -28,6 +28,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_DEEPSEEK, LLM_CHAT_TEMPLATE_DEEPSEEK_2, LLM_CHAT_TEMPLATE_DEEPSEEK_3, + LLM_CHAT_TEMPLATE_DEEPSEEK_OCR, LLM_CHAT_TEMPLATE_COMMAND_R, LLM_CHAT_TEMPLATE_LLAMA_3, LLM_CHAT_TEMPLATE_CHATGLM_3, @@ -38,7 +39,8 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_EXAONE_4, LLM_CHAT_TEMPLATE_EXAONE_MOE, LLM_CHAT_TEMPLATE_RWKV_WORLD, - LLM_CHAT_TEMPLATE_GRANITE, + LLM_CHAT_TEMPLATE_GRANITE_3_X, + LLM_CHAT_TEMPLATE_GRANITE_4_0, LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_YANDEX, @@ -51,6 +53,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, + LLM_CHAT_TEMPLATE_HUNYUAN_OCR, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 1f7a52d7895..8126249e143 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -1,5 +1,6 @@ #include "llama-context.h" +#include "ggml.h" #include "llama-arch.h" #include "llama-impl.h" #include "llama-batch.h" @@ -8,6 +9,7 @@ #include "llama-mmap.h" #include "llama-model.h" #include "llama-ext.h" +#include "llama.h" #include #include @@ -217,10 +219,10 @@ llama_context::llama_context( if (!hparams.vocab_only) { // GPU backends - for (auto * dev : model.devices) { - ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + for (const auto & dev : model.devices) { + ggml_backend_t backend = ggml_backend_dev_init(dev.dev, nullptr); if (backend == nullptr) { - throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); + throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev.dev))); } backends.emplace_back(backend); } @@ -295,8 +297,8 @@ llama_context::llama_context( if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) { // use the host buffer of the first device CPU for faster transfer of the intermediate state - auto * dev = model.devices[0]; - auto * host_buft = ggml_backend_dev_host_buffer_type(dev); + const auto & dev = model.devices[0]; + auto * host_buft = ggml_backend_dev_host_buffer_type(dev.dev); if (host_buft) { buft = host_buft; } @@ -342,14 +344,6 @@ llama_context::llama_context( if (cparams.pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); - - if (!graph_reuse_disable) { - // TODO: figure out a way to make graph reuse work with pipeline parallelism - // ref: https://github.com/ggml-org/llama.cpp/pull/20463 - LLAMA_LOG_WARN("%s: graph reuse is currently not compatible with pipeline parallelism - disabling\n", __func__); - - graph_reuse_disable = true; - } } sched_reserve(); @@ -594,7 +588,7 @@ void llama_context::sched_reserve() { // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - // TODO: not sure if the following graph would be worster case for multi-stream KV caches: + // TODO: not sure if the following graph would be worst case for multi-stream KV caches: // // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); // @@ -1028,9 +1022,11 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void for (auto & backend : backends) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); - auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); - if (set_abort_callback_fn) { - set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); + if (reg) { + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); + } } } } @@ -1165,9 +1161,11 @@ bool llama_context::set_adapter_cvec( int32_t il_end) { LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); - // TODO: should we reserve? + bool res = cvec->apply(model, data, len, n_embd, il_start, il_end); - return cvec->apply(model, data, len, n_embd, il_start, il_end); + sched_need_reserve = true; + + return res; } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -1187,6 +1185,13 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); + // with pipeline parallelism, the previous graph_compute_async may still be running + // on the GPU. we must synchronize before set_inputs to avoid overwriting input tensors + // that the previous compute is still reading. + if (cparams.pipeline_parallel) { + ggml_backend_sched_synchronize(sched.get()); + } + n_reused++; } else { res->reset(); @@ -1345,8 +1350,11 @@ int llama_context::encode(const llama_batch & batch_inp) { const llama_seq_id seq_id = ubatch.seq_id_unq[s]; const int32_t seq_idx = ubatch.seq_idx[seq_id]; - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + // use n_embd_out (not n_embd_inp) - the pooled embedding has the model's + // output dimension, which differs from input dimension for deepstack models (e.g. qwen3vl) + const uint32_t n_embd_out = hparams.n_embd_out(); + embd_seq_out[seq_id].resize(n_embd_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_RANK: @@ -1767,12 +1775,16 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract sequence embeddings (cleared before processing each batch) auto & embd_seq_out = embd_seq; + // use n_embd_out (not n_embd_inp) - the pooled embedding has the model's + // output dimension, which differs from input dimension for deepstack models (e.g. qwen3vl) + const uint32_t n_embd_out = hparams.n_embd_out(); + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { const llama_seq_id seq_id = ubatch.seq_id_unq[s]; const int32_t seq_idx = ubatch.seq_idx[seq_id]; - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + embd_seq_out[seq_id].resize(n_embd_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_RANK: @@ -1944,6 +1956,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); return 0; } + ggml_backend_buffer_clear(buf_output.get(), 0); } float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); @@ -2623,7 +2636,7 @@ void llama_context::perf_reset() { n_reused = 0; } -std::map llama_context::memory_breakdown() const { +llama_memory_breakdown llama_context::memory_breakdown() const { std::map ret; for (const auto & [buft, size] : model.memory_breakdown()) { ret[buft].model += size; @@ -2933,7 +2946,22 @@ llama_context * llama_init_from_model( params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; } - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { + if (model->split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { + LLAMA_LOG_INFO("%s: enabling flash_attn since it is required for SPLIT_MODE_TENSOR\n", __func__); + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; + } + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_ENABLED) { + LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__); + return nullptr; + } + if (ggml_is_quantized(params.type_k) || ggml_is_quantized(params.type_v)) { + LLAMA_LOG_ERROR("%s: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented\n", __func__); + return nullptr; + } + } + + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) { const uint32_t blck_size = ggml_blck_size(params.type_k); for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { if (model->hparams.n_embd_head_k(il) % blck_size != 0) { @@ -2944,7 +2972,7 @@ llama_context * llama_init_from_model( } } - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { if (model->hparams.n_embd_head_v(il) % blck_size != 0) { @@ -3465,142 +3493,6 @@ void llama_perf_context_reset(llama_context * ctx) { ctx->perf_reset(); } -void llama_memory_breakdown_print(const struct llama_context * ctx) { - const std::vector & devices = ctx->get_model().devices; - - std::map memory_breakdown = ctx->memory_breakdown(); - - std::vector> table_data; - table_data.reserve(devices.size()); - const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n"; - const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n"; - const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n"; - - table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"}); - - constexpr size_t MiB = 1024 * 1024; - const std::vector desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "}; - - // track seen buffer types to avoid double counting: - std::set seen_buffer_types; - - // accumulative memory breakdown for each device and for host: - std::vector mb_dev(devices.size()); - llama_memory_breakdown_data mb_host; - - for (const auto & buft_mb : memory_breakdown) { - ggml_backend_buffer_type_t buft = buft_mb.first; - const llama_memory_breakdown_data & mb = buft_mb.second; - if (ggml_backend_buft_is_host(buft)) { - mb_host.model += mb.model; - mb_host.context += mb.context; - mb_host.compute += mb.compute; - seen_buffer_types.insert(buft); - continue; - } - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (dev) { - int i_dev = -1; - for (size_t i = 0; i < devices.size(); i++) { - if (devices[i] == dev) { - i_dev = i; - break; - } - } - if (i_dev != -1) { - mb_dev[i_dev].model += mb.model; - mb_dev[i_dev].context += mb.context; - mb_dev[i_dev].compute += mb.compute; - seen_buffer_types.insert(buft); - continue; - } - } - } - - // print memory breakdown for each device: - for (size_t i = 0; i < devices.size(); i++) { - ggml_backend_dev_t dev = devices[i]; - llama_memory_breakdown_data mb = mb_dev[i]; - - const std::string name = ggml_backend_dev_name(dev); - std::string desc = ggml_backend_dev_description(dev); - for (const std::string & prefix : desc_prefixes_strip) { - if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) { - desc = desc.substr(prefix.length()); - } - } - - size_t free, total; - ggml_backend_dev_memory(dev, &free, &total); - - const size_t self = mb.model + mb.context + mb.compute; - const size_t unaccounted = total - self - free; - - table_data.push_back({ - template_gpu, - " - " + name + " (" + desc + ")", - std::to_string(total / MiB), - std::to_string(free / MiB), - std::to_string(self / MiB), - std::to_string(mb.model / MiB), - std::to_string(mb.context / MiB), - std::to_string(mb.compute / MiB), - std::to_string(unaccounted / MiB)}); - } - - // print memory breakdown for host: - { - const size_t self = mb_host.model + mb_host.context + mb_host.compute; - table_data.push_back({ - template_other, - " - Host", - "", // total - "", // free - std::to_string(self / MiB), - std::to_string(mb_host.model / MiB), - std::to_string(mb_host.context / MiB), - std::to_string(mb_host.compute / MiB), - ""}); // unaccounted - } - - // print memory breakdown for all remaining buffer types: - for (const auto & buft_mb : memory_breakdown) { - ggml_backend_buffer_type_t buft = buft_mb.first; - const llama_memory_breakdown_data & mb = buft_mb.second; - if (seen_buffer_types.count(buft) == 1) { - continue; - } - const std::string name = ggml_backend_buft_name(buft); - const size_t self = mb.model + mb.context + mb.compute; - table_data.push_back({ - template_other, - " - " + name, - "", // total - "", // free - std::to_string(self / MiB), - std::to_string(mb.model / MiB), - std::to_string(mb.context / MiB), - std::to_string(mb.compute / MiB), - ""}); // unaccounted - seen_buffer_types.insert(buft); - } - - for (size_t j = 1; j < table_data[0].size(); j++) { - size_t max_len = 0; - for (const auto & td : table_data) { - max_len = std::max(max_len, td[j].length()); - } - for (auto & td : table_data) { - td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' '); - } - } - for (const auto & td : table_data) { - LLAMA_LOG_INFO(td[0].c_str(), - __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(), - td[6].c_str(), td[7].c_str(), td[8].c_str()); - } -} - // // training // @@ -3631,3 +3523,11 @@ void llama_opt_epoch( callback_train, callback_eval); } + +// +// ext +// + +llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) { + return ctx->memory_breakdown(); +} diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index e0d0085c1c3..53c705eaffc 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-ext.h" #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" @@ -22,17 +23,6 @@ class llama_io_write_i; struct llama_memory_i; struct llama_memory_context_i; -// "memory" as in physical memory for a buffer type, in bytes -struct llama_memory_breakdown_data { - size_t model = 0; // memory allocated for the model - size_t context = 0; // memory allocated for the context - size_t compute = 0; // memory allocated for temporary compute buffers - - size_t total() const { - return model + context + compute; - } -}; - struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -172,7 +162,7 @@ struct llama_context { llama_perf_context_data perf_get_data() const; void perf_reset(); - std::map memory_breakdown() const; + llama_memory_breakdown memory_breakdown() const; // // training diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index 13ced783b42..8ce29d217cb 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -1,8 +1,12 @@ #pragma once -#include "llama-context.h" -#include "ggml.h" -#include "stdint.h" +// this is a staging header for new llama.cpp API +// breaking changes and C++ are allowed. everything here should be considered WIP + +#include "llama.h" + +#include +#include // Reserve a new compute graph. It is valid until the next call to llama_graph_reserve. LLAMA_API struct ggml_cgraph * llama_graph_reserve( @@ -10,3 +14,77 @@ LLAMA_API struct ggml_cgraph * llama_graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs); + +// Get the default ggml_type for a given ftype. +LLAMA_API ggml_type llama_ftype_get_default_type(llama_ftype ftype); + +struct quantize_state_impl; + +LLAMA_API quantize_state_impl * llama_quant_init( + const llama_model * model, + const llama_model_quantize_params * params); + +LLAMA_API void llama_quant_free(quantize_state_impl * qs); + +// Descriptor for constructing a mock model for quantization testing. +struct llama_quant_model_desc { + const char * architecture; + uint32_t n_embd; + uint32_t n_ff; + uint32_t n_layer; + uint32_t n_head; + uint32_t n_head_kv; + uint32_t n_expert; + uint32_t n_embd_head_k; + uint32_t n_embd_head_v; +}; + +// Create a mock model from a metadata descriptor (for testing). +// The returned model must be freed with llama_model_free(). +LLAMA_API llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc); + +// Returns true if this tensor should be quantized (based on name, dims, params). +LLAMA_API bool llama_quant_tensor_allows_quantization( + const quantize_state_impl * qs, + const ggml_tensor * tensor); + +// Compute quantization type assignments for a list of tensors. +// All tensors should be quantizable (use llama_quant_tensor_allows_quantization to filter). +// result_types: caller-allocated array of n_tensors elements, filled with assigned types. +LLAMA_API void llama_quant_compute_types( + quantize_state_impl * qs, + llama_ftype ftype, + ggml_tensor ** tensors, + ggml_type * result_types, + size_t n_tensors); + +// +// device memory querying +// + +// "memory" as in physical memory for a buffer type, in bytes +struct llama_memory_breakdown_data { + size_t model = 0; // memory allocated for the model + size_t context = 0; // memory allocated for the context + size_t compute = 0; // memory allocated for temporary compute buffers + + size_t total() const { + return model + context + compute; + } +}; + +struct llama_device_memory_data { + int64_t total; + int64_t free; + llama_memory_breakdown_data mb; +}; + +// TODO: convert to C-style data structure +using llama_memory_breakdown = std::map; + +LLAMA_API int32_t llama_model_n_expert (const struct llama_model * model); +LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); + +LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); + +LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); diff --git a/examples/talk-llama/llama-grammar.cpp b/examples/talk-llama/llama-grammar.cpp index aac0d41f2b4..badcbfd0fbb 100644 --- a/examples/talk-llama/llama-grammar.cpp +++ b/examples/talk-llama/llama-grammar.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #define MAX_REPETITION_THRESHOLD 2000 @@ -454,6 +455,7 @@ const char * llama_grammar_parser::parse_sequence( bool is_nested) { size_t last_sym_start = rule.size(); const char * pos = src; + uint64_t n_prev_rules = 1; // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used // (though it's technically the same as -1 now) @@ -481,6 +483,18 @@ const char * llama_grammar_parser::parse_sequence( // S' ::= S | llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + // Calculate the total number of rules that will be generated by this repetition + uint64_t total_rules = 1; // Start with 1 for the original rule + if (!no_max && max_times > 0) { + total_rules = max_times; + } else if (min_times > 0) { + total_rules = min_times; + } + + if (n_prev_rules * total_rules >= MAX_REPETITION_THRESHOLD) { + throw std::runtime_error("number of rules that are going to be repeated multiplied by the new repetition exceeds sane defaults, please reduce the number of repetitions or rule complexity"); + } + if (min_times == 0) { rule.resize(last_sym_start); } else { @@ -508,12 +522,15 @@ const char * llama_grammar_parser::parse_sequence( if (n_opt > 0) { rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); } + n_prev_rules *= total_rules; + GGML_ASSERT(n_prev_rules >= 1); }; while (*pos) { if (*pos == '"') { // literal string pos++; last_sym_start = rule.size(); + n_prev_rules = 1; while (*pos != '"') { if (!*pos) { throw std::runtime_error("unexpected end of input"); @@ -531,6 +548,7 @@ const char * llama_grammar_parser::parse_sequence( start_type = LLAMA_GRETYPE_CHAR_NOT; } last_sym_start = rule.size(); + n_prev_rules = 1; while (*pos != ']') { if (!*pos) { throw std::runtime_error("unexpected end of input"); @@ -561,6 +579,7 @@ const char * llama_grammar_parser::parse_sequence( auto token_pair = parse_token(vocab, pos); const char * token_end = token_pair.second; last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({type, token_pair.first}); pos = parse_space(token_end, is_nested); } else if (is_word_char(*pos)) { // rule reference @@ -568,12 +587,15 @@ const char * llama_grammar_parser::parse_sequence( uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); pos = parse_space(name_end, is_nested); last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); } else if (*pos == '(') { // grouping // parse nested alternates into synthesized rule pos = parse_space(pos + 1, true); + uint32_t n_rules_before = symbol_ids.size(); uint32_t sub_rule_id = generate_symbol_id(rule_name); pos = parse_alternates(pos, rule_name, sub_rule_id, true); + n_prev_rules = std::max(1u, (uint32_t)symbol_ids.size() - n_rules_before); last_sym_start = rule.size(); // output reference to synthesized rule rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); @@ -583,6 +605,7 @@ const char * llama_grammar_parser::parse_sequence( pos = parse_space(pos + 1, is_nested); } else if (*pos == '.') { // any char last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); pos = parse_space(pos + 1, is_nested); } else if (*pos == '*') { @@ -830,32 +853,54 @@ static bool llama_grammar_match_token( static void llama_grammar_advance_stack( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks) { - if (stack.empty()) { - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { - new_stacks.emplace_back(stack); + llama_grammar_stacks & new_stacks) { + std::vector todo; + todo.push_back(stack); + + auto stack_cmp = [](const llama_grammar_stack & a, const llama_grammar_stack & b) { + return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end(), + [](const llama_grammar_element * pa, const llama_grammar_element * pb) { + return pa < pb; // Compare pointer addresses + } + ); + }; + + std::set seen(stack_cmp); + + while (!todo.empty()) { + llama_grammar_stack curr_stack = std::move(todo.back()); + todo.pop_back(); + + if (seen.find( curr_stack) != seen.end()) { + continue; } - return; - } + seen.insert(curr_stack); - const llama_grammar_element * pos = stack.back(); + if (curr_stack.empty()) { + if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) { + new_stacks.emplace_back(std::move(curr_stack)); + } + continue; + } - switch (pos->type) { + const llama_grammar_element * pos = curr_stack.back(); + + switch (pos->type) { case LLAMA_GRETYPE_RULE_REF: { const size_t rule_id = static_cast(pos->value); const llama_grammar_element * subpos = rules[rule_id].data(); do { // init new stack without the top (pos) - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + llama_grammar_stack next_stack(curr_stack.begin(), curr_stack.end() - 1); if (!llama_grammar_is_end_of_sequence(pos + 1)) { // if this rule ref is followed by another element, add that to stack - new_stack.push_back(pos + 1); + next_stack.push_back(pos + 1); } if (!llama_grammar_is_end_of_sequence(subpos)) { // if alternate is nonempty, add to stack - new_stack.push_back(subpos); + next_stack.push_back(subpos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + todo.push_back(std::move(next_stack)); while (!llama_grammar_is_end_of_sequence(subpos)) { // scan to end of alternate def subpos++; @@ -874,9 +919,9 @@ static void llama_grammar_advance_stack( case LLAMA_GRETYPE_CHAR_ANY: case LLAMA_GRETYPE_TOKEN: case LLAMA_GRETYPE_TOKEN_NOT: - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) { // only add the stack if it's not a duplicate of one we already have - new_stacks.emplace_back(stack); + new_stacks.emplace_back(std::move(curr_stack)); } break; default: @@ -884,6 +929,7 @@ static void llama_grammar_advance_stack( // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on // those GGML_ABORT("fatal error"); + } } } diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 9a215bb77a0..2ff23f87cf4 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -1,6 +1,7 @@ #include "llama-graph.h" #include "llama-impl.h" +#include "llama-model.h" #include "llama-batch.h" #include "llama-cparams.h" @@ -19,7 +20,7 @@ // dedup helpers -static ggml_tensor * build_kq_mask( +static ggml_tensor * build_attn_inp_kq_mask( ggml_context * ctx, const llama_kv_cache_context * mctx, const llama_ubatch & ubatch, @@ -28,7 +29,11 @@ static ggml_tensor * build_kq_mask( const auto n_tokens = ubatch.n_tokens; const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_kq_mask"); + + return res; } static bool can_reuse_kq_mask( @@ -52,6 +57,21 @@ static bool can_reuse_kq_mask( // impl +static ggml_tensor * ggml_mul_mat_aux( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * rot) { + const auto n = rot->ne[0]; + + ggml_tensor * res; + + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_mul_mat (ctx, rot, res); + res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + return res; +} + void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -429,6 +449,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) { mctx->set_input_v_idxs(self_v_idxs, ubatch); mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + + if (self_k_rot) { + mctx->set_input_k_rot(self_k_rot); + } + + if (self_v_rot) { + mctx->set_input_v_rot(self_v_rot); + } } bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { @@ -476,6 +504,22 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + + if (self_k_rot) { + mctx->get_base()->set_input_k_rot(self_k_rot); + } + + if (self_v_rot) { + mctx->get_base()->set_input_v_rot(self_v_rot); + } + + if (self_k_rot_swa) { + mctx->get_swa()->set_input_k_rot(self_k_rot_swa); + } + + if (self_v_rot_swa) { + mctx->get_swa()->set_input_v_rot(self_v_rot_swa); + } } bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { @@ -532,6 +576,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + if (inp_attn->self_k_rot) { + mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot); + } + + if (inp_attn->self_v_rot) { + mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot); + } + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (inp_rs->s_copy) { @@ -630,6 +682,22 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); } + if (inp_attn->self_k_rot) { + attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot); + } + + if (inp_attn->self_v_rot) { + attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot); + } + + if (inp_attn->self_k_rot_swa) { + attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa); + } + + if (inp_attn->self_v_rot_swa) { + attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa); + } + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (inp_rs->s_copy) { @@ -992,6 +1060,84 @@ ggml_tensor * llm_graph_context::build_norm( return cur; } + +llm_graph_qkv llm_graph_context::build_qkv( + const llama_layer & layer, + ggml_tensor * cur, + int64_t n_embd_head, + int64_t n_head, + int64_t n_head_kv, + int il) const { + const int64_t n_embd_q = n_embd_head * n_head; + const int64_t n_embd_kv = n_embd_head * n_head_kv; + + ggml_tensor * Qcur, * Kcur, * Vcur; + + if (layer.wqkv) { + // fused QKV path + ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur, layer.wqkv_s); + cb(qkv, "wqkv", il); + if (layer.wqkv_b) { + qkv = ggml_add(ctx0, qkv, layer.wqkv_b); + cb(qkv, "wqkv_b", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + qkv = ggml_clamp(ctx0, qkv, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(qkv, "wqkv_clamped", il); + } + Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head, n_tokens, + ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], 0); + Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, + ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], + ggml_row_size(qkv->type, n_embd_q)); + Vcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, + ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], + ggml_row_size(qkv->type, n_embd_q + n_embd_kv)); + } else { + // separate Q/K/V path + Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur, "Qcur", il); + if (layer.wq_b) { + Qcur = ggml_add(ctx0, Qcur, layer.wq_b); + cb(Qcur, "Qcur", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Qcur, "Qcur_clamped", il); + } + Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + cb(Kcur, "Kcur", il); + if (layer.wk_b) { + Kcur = ggml_add(ctx0, Kcur, layer.wk_b); + cb(Kcur, "Kcur", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Kcur, "Kcur_clamped", il); + } + Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + cb(Vcur, "Vcur", il); + if (layer.wv_b) { + Vcur = ggml_add(ctx0, Vcur, layer.wv_b); + cb(Vcur, "Vcur", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Vcur, "Vcur_clamped", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + return { Qcur, Kcur, Vcur }; +} + + ggml_tensor * llm_graph_context::build_ffn( ggml_tensor * cur, ggml_tensor * up, @@ -1516,9 +1662,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); - cb(cur, "ffn_moe_weighted", il); + cb(experts, "ffn_moe_weighted", il); } + ggml_build_forward_expand(gf, experts); + ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr }; assert(n_expert_used > 0); @@ -1538,6 +1686,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( for (uint32_t i = 1; i < hparams.n_expert_used; ++i) { moe_out = ggml_add(ctx0, moe_out, cur_experts[i]); + + ggml_build_forward_expand(gf, moe_out); } if (hparams.n_expert_used == 1) { @@ -1665,7 +1815,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const { ggml_tensor * llm_graph_context::build_inp_out_ids() const { // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls, - // but this would make the graph topology depend on the number of output tokens, which can interere with + // but this would make the graph topology depend on the number of output tokens, which can interfere with // features that require constant topology such as pipeline parallelism // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471 //if (n_outputs < n_tokens) { @@ -1940,6 +2090,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_no_cache * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -1973,7 +2124,7 @@ ggml_tensor * llm_graph_context::build_attn( cb(cur, "kqv_out", il); if (wo) { - cur = build_lora_mm(wo, cur); + cur = build_lora_mm(wo, cur, wo_s); } if (wo_b) { @@ -2002,13 +2153,13 @@ static std::unique_ptr build_attn_inp_kv_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); - - ggml_set_input(inp->self_kq_mask); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } + inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0); + inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0); + return inp; } @@ -2024,6 +2175,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -2034,6 +2186,15 @@ ggml_tensor * llm_graph_context::build_attn( int il) const { GGML_ASSERT(v_mla == nullptr); + if (inp->self_k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot); + k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot); + } + + if (inp->self_v_rot) { + v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot); + } + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced // expand k later to enable rope fusion which directly writes into k-v cache @@ -2061,11 +2222,20 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (inp->self_v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot); + } + if (wo) { - cur = build_lora_mm(wo, cur); if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators + cur = build_lora_mm(wo, cur); ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + if (wo_s) { + cur = ggml_mul(ctx0, cur, wo_s); + } + } else { + cur = build_lora_mm(wo, cur, wo_s); } } @@ -2090,9 +2260,7 @@ static std::unique_ptr build_attn_inp_k_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); - ggml_set_input(inp->self_kq_mask); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } @@ -2111,6 +2279,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_k * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -2145,10 +2314,15 @@ ggml_tensor * llm_graph_context::build_attn( cb(cur, "kqv_out", il); if (wo) { - cur = build_lora_mm(wo, cur); if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + cur = build_lora_mm(wo, cur); ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + if (wo_s) { + cur = ggml_mul(ctx0, cur, wo_s); + } + } else { + cur = build_lora_mm(wo, cur, wo_s); } } @@ -2163,6 +2337,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -2171,6 +2346,23 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_mla, float kq_scale, int il) const { + const bool is_swa = hparams.is_swa(il); + + auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot; + auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot; + + if (k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot); + if (k_cur) { + k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot); + } + } + if (v_rot) { + if (v_cur) { + v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot); + } + } + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); @@ -2185,8 +2377,6 @@ ggml_tensor * llm_graph_context::build_attn( const auto * mctx_iswa = inp->mctx; - const bool is_swa = hparams.is_swa(il); - const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base(); // optionally store to KV cache @@ -2211,8 +2401,12 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, v_rot); + } + if (wo) { - cur = build_lora_mm(wo, cur); + cur = build_lora_mm(wo, cur, wo_s); } if (wo_b) { @@ -2243,6 +2437,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_cross * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -2267,7 +2462,7 @@ ggml_tensor * llm_graph_context::build_attn( cb(cur, "kqv_out", il); if (wo) { - cur = build_lora_mm(wo, cur); + cur = build_lora_mm(wo, cur, wo_s); } if (wo_b) { @@ -2293,12 +2488,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); - ggml_set_input(inp->self_kq_mask); - ggml_set_name(inp->self_kq_mask, "self_kq_mask"); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; - ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv"); } { @@ -2307,14 +2498,16 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); - ggml_set_input(inp->self_kq_mask_swa); - ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); - + inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; - ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv"); } + inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0); + inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0); + + inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0); + inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0); + return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); } @@ -2348,7 +2541,7 @@ ggml_tensor * llm_graph_context::build_rs( ggml_build_forward_expand(gf, ggml_cpy(ctx0, states_extra, - ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s)))); + ggml_view_2d(ctx0, s, state_size, (n_rs - n_seqs), s->nb[1], (rs_head + n_seqs)*s->nb[1]))); return output_states; } @@ -2473,9 +2666,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); - ggml_set_input(inp_attn->self_kq_mask); - + inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; } @@ -2483,9 +2674,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); - ggml_set_input(inp_attn->self_kq_mask_swa); - + inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; } diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 4855685ef71..5cb1756c6a9 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -17,6 +17,7 @@ struct ggml_context; struct ggml_tensor; struct llama_cparams; +struct llama_layer; struct llama_memory_context_i; @@ -308,6 +309,10 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + // note: assumes v_rot^2 == I + ggml_tensor * self_k_rot = nullptr; + ggml_tensor * self_v_rot = nullptr; + // note: these have to be copies because in order to be able to reuse a graph, its inputs // need to carry these parameters with them. otherwise, they can point to freed // llm_graph_params from a previous batch, causing stack-use-after-return @@ -384,6 +389,12 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_k_rot = nullptr; + ggml_tensor * self_v_rot = nullptr; + + ggml_tensor * self_k_rot_swa = nullptr; + ggml_tensor * self_v_rot_swa = nullptr; + const llama_hparams hparams; const llama_cparams cparams; @@ -697,6 +708,12 @@ using llm_graph_result_ptr = std::unique_ptr; // used in build_rs to properly order writes and avoid unnecessary copies using llm_graph_get_rows_fn = std::function; +struct llm_graph_qkv { + ggml_tensor * q; // [n_embd_head, n_head, n_tokens] + ggml_tensor * k; // [n_embd_head, n_head_kv, n_tokens] + ggml_tensor * v; // [n_embd_head, n_head_kv, n_tokens] +}; + struct llm_graph_context { const llm_arch arch; @@ -783,6 +800,17 @@ struct llm_graph_context { llm_norm_type type, int il) const; + + // compute Q, K, V projections with optional bias and reshape + // supports both fused wqkv and separate wq/wk/wv paths + llm_graph_qkv build_qkv( + const llama_layer & layer, + ggml_tensor * cur, + int64_t n_embd_head, + int64_t n_head, + int64_t n_head_kv, + int il) const; + ggml_tensor * build_ffn( ggml_tensor * cur, ggml_tensor * up, @@ -882,6 +910,7 @@ struct llm_graph_context { llm_graph_input_attn_no_cache * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] @@ -897,6 +926,7 @@ struct llm_graph_context { llm_graph_input_attn_kv * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] @@ -912,6 +942,7 @@ struct llm_graph_context { llm_graph_input_attn_k * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] @@ -928,6 +959,7 @@ struct llm_graph_context { llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional @@ -943,6 +975,7 @@ struct llm_graph_context { llm_graph_input_attn_cross * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 78c0bc27d4d..ac7f9ee8650 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -116,6 +116,7 @@ struct llama_hparams { float rope_freq_base_train_swa = 10000.0f; float rope_freq_scale_train; float rope_freq_scale_train_swa = 1.0f; + float rope_scaling_alpha = 0.0f; // NTK-aware alpha for XDRoPE uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; @@ -209,6 +210,9 @@ struct llama_hparams { // qwen3vl deepstack uint32_t n_deepstack_layers = 0; + // gemma4 per-layer embedding + uint32_t n_embd_per_layer = 0; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggml-org/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/examples/talk-llama/llama-impl.cpp b/examples/talk-llama/llama-impl.cpp index 4c0188ee722..b3a94b946d2 100644 --- a/examples/talk-llama/llama-impl.cpp +++ b/examples/talk-llama/llama-impl.cpp @@ -128,7 +128,7 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); - case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false"; + case GGUF_TYPE_BOOL: return ((const int8_t *)data)[i] != 0 ? "true" : "false"; default: return format("unknown type %d", type); } } diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 01166fac9ce..09102f549c8 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -13,6 +13,65 @@ #include #include +static bool ggml_is_power_of_2(int n) { + return (n & (n - 1)) == 0; +} + +// orthonormal Walsh-Hadamard rotation matrix +// note: res^2 == I +static void ggml_gen_hadamard(ggml_tensor * tensor) { + assert(tensor->type == GGML_TYPE_F32); + + const int n = tensor->ne[0]; + + assert(ggml_is_power_of_2(n)); + assert(tensor->ne[1] == n); + assert(tensor->ne[2] == 1); + assert(tensor->ne[3] == 1); + + std::vector data_f32; + + float * data = (float *) tensor->data; + + if (tensor->type != GGML_TYPE_F32) { + data_f32.resize(n*n); + data = data_f32.data(); + } + + data[0*n + 0] = 1.0 / sqrtf(n); + + for (int s = 1; s < n; s *= 2) { + for (int i = 0; i < s; i++) { + for (int j = 0; j < s; j++) { + const float val = data[i*n + j]; + + data[(i + s)*n + (j )] = val; + data[(i )*n + (j + s)] = val; + data[(i + s)*n + (j + s)] = -val; + } + } + } + + if (tensor->type != GGML_TYPE_F32) { + ggml_quantize_chunk(tensor->type, data, tensor->data, 0, 1, n*n, nullptr); + } +} + +static ggml_tensor * ggml_mul_mat_aux( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * rot) { + const auto n = rot->ne[0]; + + ggml_tensor * res; + + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_mul_mat (ctx, rot, res); + res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + return res; +} + // // llama_kv_cache // @@ -110,6 +169,18 @@ llama_kv_cache::llama_kv_cache( continue; } + if (n_embd_head_k_all == 0) { + n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il); + } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) { + n_embd_head_k_all = -1; + } + + if (n_embd_head_v_all == 0) { + n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il); + } else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) { + n_embd_head_v_all = -1; + } + // [TAG_V_CACHE_VARIABLE] const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max(); @@ -209,6 +280,48 @@ llama_kv_cache::llama_kv_cache( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } + const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); + const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; + if (attn_rot_disable) { + LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); + } + + attn_rot_k = + !attn_rot_disable && + n_embd_head_k_all > 0 && + ggml_is_quantized(type_k) && + hparams.n_embd_head_k() % 64 == 0; + + attn_rot_v = + !attn_rot_disable && + n_embd_head_v_all > 0 && + ggml_is_quantized(type_v) && + hparams.n_embd_head_v() % 64 == 0; + + LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all); + LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all); + + // pre-compute the haramard matrices and keep them in host memory + // TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers + if (attn_rot_k || attn_rot_v) { + for (int64_t n = 64; n <= std::max(n_embd_head_k_all, n_embd_head_v_all); n *= 2) { + attn_rot_hadamard[n] = std::vector(n*n); + + ggml_init_params params = { + /* .mem_size = */ 1*ggml_tensor_overhead(), + /* .mem_buffer = */ nullptr, + /* .no_alloc = */ true, + }; + + ggml_context_ptr ctx { ggml_init(params) }; + + ggml_tensor * tmp = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, n, n); + tmp->data = attn_rot_hadamard[n].data(); + + ggml_gen_hadamard(tmp); + } + } + const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; } @@ -1004,6 +1117,14 @@ bool llama_kv_cache::get_has_shift() const { return result; } +ggml_type llama_kv_cache::type_k() const { + return layers[0].k->type; +} + +ggml_type llama_kv_cache::type_v() const { + return layers[0].v->type; +} + uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; @@ -1189,6 +1310,47 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama return v_idxs; } +ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const { + ggml_tensor * res = nullptr; + + if (attn_rot_k) { + int nrot = 64; + + // TODO: investigate if using the smallest rotation matrix is beneficial also for K (similar as for V) + // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088 + do { + nrot *= 2; + } while (n_embd_head_k_all % nrot == 0); + nrot /= 2; + + res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_k_rot"); + } + + return res; +} + +ggml_tensor * llama_kv_cache::build_input_v_rot(ggml_context * ctx) const { + ggml_tensor * res = nullptr; + + if (attn_rot_v) { + int nrot = 64; + // using smaller rotation matrices for V seems beneficial + // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4146397570 + //do { + // nrot *= 2; + //} while (hparams.n_embd_head_v() % nrot == 0); + //nrot /= 2; + + res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_v_rot"); + } + + return res; +} + void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); @@ -1507,6 +1669,24 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch } } +void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const auto n_rot = dst->ne[0]; + GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); + + memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); +} + +void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const auto n_rot = dst->ne[0]; + GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); + + memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); +} + size_t llama_kv_cache::total_size() const { size_t size = 0; @@ -1542,6 +1722,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift( ggml_context * ctx, ggml_tensor * cur, ggml_tensor * shift, + ggml_tensor * rot, ggml_tensor * factors, float freq_base, float freq_scale, @@ -1561,17 +1742,22 @@ ggml_tensor * llama_kv_cache::build_rope_shift( // ref: https://github.com/ggml-org/llama.cpp/pull/13870 ? LLAMA_ROPE_TYPE_NEOX : hparams.rope_type; - ggml_tensor * tmp; if (ggml_is_quantized(cur->type)) { // dequantize to f32 -> RoPE -> quantize back tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + // rotate back + tmp = ggml_mul_mat_aux(ctx, tmp, rot); + tmp = ggml_rope_ext(ctx, tmp, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + // rotate fwd + tmp = ggml_mul_mat_aux(ctx, tmp, rot); + tmp = ggml_cpy(ctx, tmp, cur); } else { // we rotate only the first n_rot dimensions @@ -1592,6 +1778,9 @@ class llm_graph_input_k_shift : public llm_graph_input_i { ggml_tensor * k_shift; // I32 [kv_size*n_stream] + // note: assumes k_rot^2 == I + ggml_tensor * k_rot = nullptr; + const llama_kv_cache * kv_self; }; @@ -1601,6 +1790,10 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { if (k_shift) { kv_self->set_input_k_shift(k_shift); } + + if (k_rot) { + kv_self->set_input_k_rot(k_rot); + } } ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { @@ -1612,6 +1805,8 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); ggml_set_input(inp->k_shift); + inp->k_rot = build_input_k_rot(ctx); + const auto & cparams = lctx->get_cparams(); for (const auto & layer : layers) { @@ -1636,7 +1831,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_row_size(layer.k->type, n_embd_k_gqa), ggml_row_size(layer.k->type, n_embd_nope)); - ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il); + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->k_rot, rope_factors, freq_base_l, freq_scale_l, il); ggml_build_forward_expand(gf, cur); } @@ -2240,6 +2435,14 @@ uint32_t llama_kv_cache_context::get_n_kv() const { return n_kv; } +ggml_type llama_kv_cache_context::type_k() const { + return kv->type_k(); +} + +ggml_type llama_kv_cache_context::type_v() const { + return kv->type_v(); +} + ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const { return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } @@ -2264,6 +2467,14 @@ ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, con return kv->build_input_v_idxs(ctx, ubatch); } +ggml_tensor * llama_kv_cache_context::build_input_k_rot(ggml_context * ctx) const { + return kv->build_input_k_rot(ctx); +} + +ggml_tensor * llama_kv_cache_context::build_input_v_rot(ggml_context * ctx) const { + return kv->build_input_v_rot(ctx); +} + void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const { kv->set_input_k_shift(dst); } @@ -2283,3 +2494,11 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } + +void llama_kv_cache_context::set_input_k_rot(ggml_tensor * dst) const { + kv->set_input_k_rot(dst); +} + +void llama_kv_cache_context::set_input_v_rot(ggml_tensor * dst) const { + kv->set_input_v_rot(dst); +} diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index 33c78c5f210..0b62dc7b232 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -152,6 +152,9 @@ class llama_kv_cache : public llama_memory_i { bool get_has_shift() const; + ggml_type type_k() const; + ggml_type type_v() const; + // // graph_build API // @@ -191,6 +194,9 @@ class llama_kv_cache : public llama_memory_i { ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + ggml_tensor * build_input_v_rot(ggml_context * ctx) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; @@ -199,6 +205,9 @@ class llama_kv_cache : public llama_memory_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_rot(ggml_tensor * dst) const; + void set_input_v_rot(ggml_tensor * dst) const; + private: const llama_model & model; const llama_hparams & hparams; @@ -226,6 +235,18 @@ class llama_kv_cache : public llama_memory_i { // SWA const uint32_t n_swa = 0; + // env: LLAMA_ATTN_ROT_DISABLE + bool attn_rot_k = false; + bool attn_rot_v = false; + + // if all layers participating in the cache have constant head size, the value is stored here + // otherwise the value is -1 + int32_t n_embd_head_k_all = 0; + int32_t n_embd_head_v_all = 0; + + // pre-computed hadamard martrices + std::unordered_map> attn_rot_hadamard; + // env: LLAMA_KV_CACHE_DEBUG int debug = 0; @@ -262,6 +283,7 @@ class llama_kv_cache : public llama_memory_i { ggml_context * ctx, ggml_tensor * cur, ggml_tensor * shift, + ggml_tensor * rot, ggml_tensor * factors, float freq_base, float freq_scale, @@ -328,12 +350,15 @@ class llama_kv_cache_context : public llama_memory_context_i { uint32_t get_n_kv() const; + ggml_type type_k() const; + ggml_type type_v() const; + // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; // store k_cur and v_cur in the cache based on the provided head location - // note: the heads in k_cur and v_cur should be layed out contiguously in memory + // note: the heads in k_cur and v_cur should be laid out contiguously in memory // - k_cur [n_embd_head_k, n_head_k, n_tokens] // - k_idxs [n_tokens] // - v_cur [n_embd_head_v, n_head_v, n_tokens] @@ -347,6 +372,9 @@ class llama_kv_cache_context : public llama_memory_context_i { ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + ggml_tensor * build_input_v_rot(ggml_context * ctx) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; @@ -354,6 +382,9 @@ class llama_kv_cache_context : public llama_memory_context_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_rot(ggml_tensor * dst) const; + void set_input_v_rot(ggml_tensor * dst) const; + private: llama_memory_status status; diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp index 411769672af..10e6b459797 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.cpp +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_base()->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index a1b45e4a3cc..4ce1af592c1 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 6e8413f493d..9287fe45e96 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -1,5 +1,6 @@ #include "llama-memory-recurrent.h" +#include "ggml-backend.h" #include "llama-impl.h" #include "llama-io.h" #include "llama-batch.h" @@ -91,8 +92,8 @@ llama_memory_recurrent::llama_memory_recurrent( throw std::runtime_error("failed to create ggml context for rs cache"); } - ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size); - ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size); + ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); r_l[i] = r; @@ -928,11 +929,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell llama_seq_id seq_id; io.read_to(&seq_id, sizeof(seq_id)); - // TODO: llama_memory_recurrent should have a notion of max sequences - //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - if (seq_id < 0) { - //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max); return false; } diff --git a/examples/talk-llama/llama-mmap.cpp b/examples/talk-llama/llama-mmap.cpp index c03228e9ce2..ed572da7fb5 100644 --- a/examples/talk-llama/llama-mmap.cpp +++ b/examples/talk-llama/llama-mmap.cpp @@ -40,6 +40,14 @@ #include #endif +#ifdef _WIN32 +# define llama_mmap_ftell _ftelli64 +# define llama_mmap_fseek _fseeki64 +#else +# define llama_mmap_ftell ftello +# define llama_mmap_fseek fseeko +#endif + // TODO: consider moving to llama-impl.h if needed in more places #if defined(_WIN32) static std::string llama_format_win_err(DWORD err) { @@ -86,6 +94,14 @@ struct llama_file::impl { seek(0, SEEK_SET); } + impl(FILE * file) : owns_fp(false) { + fp = file; + fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp)); + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + size_t tell() const { LARGE_INTEGER li; li.QuadPart = 0; @@ -159,7 +175,7 @@ struct llama_file::impl { } ~impl() { - if (fp) { + if (fp && owns_fp) { std::fclose(fp); } } @@ -209,9 +225,16 @@ struct llama_file::impl { seek(0, SEEK_SET); } + impl(FILE * file) : fname("(file*)"), owns_fp(false) { + fp = file; + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + size_t tell() const { if (fd == -1) { - long ret = std::ftell(fp); + off_t ret = llama_mmap_ftell(fp); if (ret == -1) { throw std::runtime_error(format("ftell error: %s", strerror(errno))); } @@ -229,7 +252,7 @@ struct llama_file::impl { void seek(size_t offset, int whence) const { off_t ret = 0; if (fd == -1) { - ret = std::fseek(fp, (long) offset, whence); + ret = llama_mmap_fseek(fp, offset, whence); } else { ret = lseek(fd, offset, whence); } @@ -353,7 +376,7 @@ struct llama_file::impl { ~impl() { if (fd != -1) { close(fd); - } else { + } else if (owns_fp) { std::fclose(fp); } } @@ -369,10 +392,14 @@ struct llama_file::impl { FILE * fp{}; size_t size{}; + bool owns_fp = true; }; llama_file::llama_file(const char * fname, const char * mode, const bool use_direct_io) : pimpl(std::make_unique(fname, mode, use_direct_io)) {} + +llama_file::llama_file(FILE * file) : pimpl(std::make_unique(file)) {} + llama_file::~llama_file() = default; size_t llama_file::tell() const { return pimpl->tell(); } diff --git a/examples/talk-llama/llama-mmap.h b/examples/talk-llama/llama-mmap.h index 29ce4d24685..b7d5c61e95f 100644 --- a/examples/talk-llama/llama-mmap.h +++ b/examples/talk-llama/llama-mmap.h @@ -15,6 +15,7 @@ using llama_mlocks = std::vector>; struct llama_file { llama_file(const char * fname, const char * mode, bool use_direct_io = false); + llama_file(FILE * file); ~llama_file(); size_t tell() const; diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 413f34c2268..4e65a45a50d 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -36,6 +36,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q1_0: return "Q1_0"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; @@ -374,8 +375,9 @@ namespace GGUFMeta { } } else { if (arr_info.gt == GGUF_TYPE_BOOL) { - std::transform((const bool *)arr_info.data, (const bool *)arr_info.data + arr_info.length, result.begin(), [](bool x) { - return static_cast(x); + const int8_t * values = (const int8_t *) arr_info.data; + std::transform(values, values + arr_info.length, result.begin(), [](int8_t x) { + return static_cast(x != 0); }); } else { std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); @@ -511,6 +513,7 @@ llama_model_loader::llama_model_loader( void * set_tensor_data_ud, const std::string & fname, std::vector & splits, + FILE * file, bool use_mmap, bool use_direct_io, bool check_tensors, @@ -658,6 +661,36 @@ llama_model_loader::llama_model_loader( LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); } + } else if (file != nullptr) { + struct ggml_context * ctx = NULL; + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + + metadata_ptr.reset(gguf_init_from_file_ptr(file, params)); + metadata = metadata_ptr.get(); + if (metadata == nullptr) { + throw std::runtime_error(format("%s: failed to load model from file pointer", __func__)); + } + + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + + files.emplace_back(new llama_file(file)); + contexts.emplace_back(ctx); + + // Save tensors data offset info of the main file. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, metadata, cur)); + } } else { get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); llm_kv = LLM_KV(llm_arch_from_string(arch_name)); @@ -669,7 +702,7 @@ llama_model_loader::llama_model_loader( fver = (enum llama_fver) gguf_get_version(metadata); LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", - __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); + __func__, n_kv, n_tensors, fname.empty() ? "(file*)" : fname.c_str(), llama_file_version_name(fver)); // determine file type based on the number of tensors for each quantization and print meta data // TODO: make optional @@ -726,6 +759,7 @@ llama_model_loader::llama_model_loader( case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; case GGML_TYPE_NVFP4: ftype = LLAMA_FTYPE_MOSTLY_NVFP4; break; + case GGML_TYPE_Q1_0: ftype = LLAMA_FTYPE_MOSTLY_Q1_0; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -1127,6 +1161,12 @@ struct ggml_tensor * llama_model_loader::create_tensor( if (overrides->buft == ggml_backend_cpu_buffer_type()) { // when overriding to a CPU buffer, consider the extra buffer types buft = select_weight_buft(hparams, t_meta, op, buft_list_cpu); + if (use_mmap) { + static std::once_flag once; + std::call_once(once, [] { + LLAMA_LOG_WARN("llama_model_loader: tensor overrides to CPU are used with mmap enabled - consider using --no-mmap for better performance\n"); + }); + } } else { buft = overrides->buft; } diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index ed5de729caf..7b3d6703c03 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -125,6 +125,7 @@ struct llama_model_loader { void * set_tensor_data_ud, const std::string & fname, std::vector & splits, // optional, only need if the split does not follow naming scheme + FILE * file, bool use_mmap, bool use_direct_io, bool check_tensors, diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index 6f6538aeccd..26864c18e97 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -1,7 +1,9 @@ #include "llama-model-saver.h" +#include "ggml.h" #include "gguf.h" +#include "llama-arch.h" #include "llama.h" #include "llama-hparams.h" #include "llama-model.h" @@ -10,8 +12,33 @@ #include #include +bool llama_model_saver_supports_arch(llm_arch arch) { + switch (arch) { + case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: + case LLM_ARCH_PLAMO3: + case LLM_ARCH_GEMMA3: + case LLM_ARCH_GEMMA3N: + case LLM_ARCH_COHERE2: + case LLM_ARCH_OLMO2: + case LLM_ARCH_BITNET: + case LLM_ARCH_T5: + case LLM_ARCH_EXAONE_MOE: + case LLM_ARCH_AFMOE: + case LLM_ARCH_APERTUS: + case LLM_ARCH_MIMO2: + case LLM_ARCH_STEP35: + return false; + default: + return true; + } +} + llama_model_saver::llama_model_saver(const struct llama_model * model) : - gguf_ctx(gguf_init_empty()), gguf_ctx_owned(true), model(model), llm_kv(model->arch) {} + gguf_ctx(gguf_init_empty()), gguf_ctx_owned(true), model(model), llm_kv(model->arch) { + GGML_ASSERT(llama_model_saver_supports_arch(model->arch)); +} llama_model_saver::llama_model_saver(enum llm_arch arch, struct gguf_context * gguf_ctx) : gguf_ctx(gguf_ctx == nullptr ? gguf_init_empty() : gguf_ctx), gguf_ctx_owned(gguf_ctx == nullptr), model(nullptr), llm_kv(arch) {} @@ -105,7 +132,10 @@ void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) { return; } if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) { - GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME + const std::string tensor_name = tensor->name; + GGML_ASSERT( + tensor_name == "rope_freqs.weight" || tensor_name == "rope_factors_long.weight" || + tensor_name == "rope_factors_short.weight"); // FIXME return; } gguf_add_tensor(gguf_ctx, tensor); @@ -127,6 +157,7 @@ void llama_model_saver::add_kv_from_model() { tokens[id] = token_data.text; scores[id] = token_data.score; + // FIXME should this be treated as flags? switch(token_data.attr) { case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; @@ -134,6 +165,9 @@ void llama_model_saver::add_kv_from_model() { case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; + // case LLAMA_TOKEN_ATTR_NORMALIZED: ??? + // case LLAMA_TOKEN_ATTR_LSTRIP: ??? + // case LLAMA_TOKEN_ATTR_RSTRIP: ??? case LLAMA_TOKEN_ATTR_UNDEFINED: default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; } @@ -144,6 +178,19 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_GENERAL_ARCHITECTURE, model->arch_name()); // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???); // add_kv(LLM_KV_GENERAL_ALIGNMENT, ???); + // add_kv(LLM_KV_GENERAL_FILE_TYPE, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_SEQUENCE, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_TOP_K, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_TOP_P, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIN_P, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_TEMP, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIROSTAT, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, ???); add_kv(LLM_KV_GENERAL_NAME, model->name); // add_kv(LLM_KV_GENERAL_AUTHOR, ???); // add_kv(LLM_KV_GENERAL_VERSION, ???); @@ -163,17 +210,31 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_chexp); + add_kv(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp); + add_kv(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp); add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); // add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???); add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert); add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + add_kv(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups); + add_kv(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used); add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + add_kv(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm); + add_kv(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + add_kv(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); + add_kv(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); + add_kv(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers); + add_kv(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers); + add_kv(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers); add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type)); add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id); + add_kv(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer); add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping); + add_kv(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping); add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping); add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm); add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers); @@ -181,6 +242,9 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + add_kv(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count); + add_kv(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + // add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true); add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); @@ -188,22 +252,39 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full); add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full); - add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); - add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + add_kv(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); + add_kv(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + add_kv(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); + add_kv(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); + add_kv(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); + add_kv(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate); add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + // add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, ???); add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + add_kv(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale); + add_kv(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length); + add_kv(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); + add_kv(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + add_kv(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + add_kv(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full); add_kv(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa); + add_kv(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections); add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train); + add_kv(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); // add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train)); add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor); @@ -211,6 +292,10 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn); add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned); add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + add_kv(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor); + add_kv(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor); + add_kv(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast); + add_kv(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow); // TODO: implement split file support // add_kv(LLM_KV_SPLIT_NO, ???); @@ -221,8 +306,11 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + add_kv(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms); + add_kv(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); + add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model()); @@ -260,15 +348,39 @@ void llama_model_saver::add_kv_from_model() { // TODO: implement LoRA support // add_kv(LLM_KV_ADAPTER_TYPE, ???); // add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???); + // add_kv(LLM_KV_ADAPTER_LORA_TASK_NAME, ???); + // add_kv(LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, ???); + // add_kv(LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, ???); + + add_kv(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); + add_kv(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); + + add_kv(LLM_KV_CONVNEXT_EMBEDDING_LENGTH, hparams.convnext.n_embd); + add_kv(LLM_KV_CONVNEXT_BLOCK_COUNT, hparams.convnext.n_layer); + + add_kv(LLM_KV_CLASSIFIER_OUTPUT_LABELS, model->classifier_labels); + + add_kv(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + + add_kv(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n); + add_kv(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p); + add_kv(LLM_KV_XIELU_BETA, hparams.xielu_beta); + add_kv(LLM_KV_XIELU_EPS, hparams.xielu_eps); // deprecated // add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???); // add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???); // add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???); + + add_kv(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in); + add_kv(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out); + add_kv(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in); + add_kv(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out); } void llama_model_saver::add_tensors_from_model() { - if (std::string(model->output->name) != std::string(model->tok_embd->name)) { + if (model->output != nullptr && + std::string(model->output->name) != std::string(model->tok_embd->name)) { add_tensor(model->tok_embd); // some models use the same tensor for tok_embd and output } add_tensor(model->type_embd); @@ -297,3 +409,6 @@ void llama_model_saver::save(const std::string & path_model) { gguf_write_to_file(gguf_ctx, path_model.c_str(), false); } +void llama_model_saver::save(FILE * file) { + gguf_write_to_file_ptr(gguf_ctx, file, false); +} diff --git a/examples/talk-llama/llama-model-saver.h b/examples/talk-llama/llama-model-saver.h index 2b3541ce6c5..36a715e2b6b 100644 --- a/examples/talk-llama/llama-model-saver.h +++ b/examples/talk-llama/llama-model-saver.h @@ -6,6 +6,9 @@ #include +// FIXME temporary function for better error messages +bool llama_model_saver_supports_arch(llm_arch arch); + struct llama_model_saver { struct gguf_context * gguf_ctx = nullptr; const bool gguf_ctx_owned; @@ -37,4 +40,5 @@ struct llama_model_saver { void add_tensors_from_model(); void save(const std::string & path_model); + void save(FILE * file); }; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index e8e1bbf1cd1..9e2a13cbd43 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1,6 +1,8 @@ #include "llama-model.h" -#include "ggml.h" +#include "llama-arch.h" +#include "llama-ext.h" +#include "llama-hparams.h" #include "llama-impl.h" #include "llama-mmap.h" #include "llama-cparams.h" @@ -12,10 +14,11 @@ #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" -#include "ggml-cpp.h" - #include "models/models.h" +#include "ggml.h" +#include "ggml-cpp.h" + #include #include #include @@ -24,9 +27,358 @@ #include #include #include +#include #include #include #include +#include +#include + +struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) { + const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata; + const llama_hparams & hparams = ud->model->hparams; + const std::string tensor_name = tensor->name; + + const std::regex pattern_q_weight ("blk\\.\\d*\\.attn_q.weight"); + const std::regex pattern_kv_weight ("blk\\.\\d*\\.attn_(k|v).weight"); + const std::regex pattern_qkv_weight ("blk\\.\\d*\\.attn_qkv.weight"); + const std::regex pattern_q_bias ("blk\\.\\d*\\.attn_q\\.bias"); + const std::regex pattern_kv_bias ("blk\\.\\d*\\.attn_(k|v)\\.bias"); + const std::regex pattern_qkv_bias ("blk\\.\\d*\\.attn_qkv.bias"); + const std::regex pattern_qk_norm ("blk\\.\\d*\\.attn_(q|k)_norm\\.weight"); + const std::regex pattern_kv_cache ("cache_(k|v)_l\\d*"); + const std::regex pattern_attn_sinks ("blk\\.\\d*\\.attn_sinks.weight"); + const std::regex pattern_attn_out_weight ("blk\\.\\d*\\.attn_output.weight"); + const std::regex pattern_attn_out_bias ("blk\\.\\d*\\.attn_output.bias"); + const std::regex pattern_attn_gate_weight("blk\\.\\d*\\.attn_gate.weight"); + + const std::regex pattern_ssm_dt ("blk\\.\\d*\\.ssm_dt.bias"); + const std::regex pattern_ssm_a ("blk\\.\\d*\\.ssm_a"); + const std::regex pattern_ssm_alpha ("blk\\.\\d*\\.ssm_alpha.weight"); + const std::regex pattern_ssm_beta ("blk\\.\\d*\\.ssm_beta.weight"); + const std::regex pattern_ssm_beta_alpha ("blk\\.\\d*\\.ssm_ba.weight"); + const std::regex pattern_r_cache ("cache_r_l\\d*"); + const std::regex pattern_s_cache ("cache_s_l\\d*"); + const std::regex pattern_ssm_conv1d ("blk\\.\\d*\\.ssm_conv1d.weight"); + const std::regex pattern_ssm_out_weight ("blk\\.\\d*\\.ssm_out.weight"); + + const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight"); + const std::regex pattern_ffn_up_gate_bias ("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias"); + const std::regex pattern_ffn_gate_up_weight("blk\\.\\d*\\.ffn_gate_up(_exps)?.weight"); + const std::regex pattern_ffn_down_weight ("blk\\.\\d*\\.ffn_down(_exps)?.weight"); + const std::regex pattern_ffn_down_bias ("blk\\.\\d*\\.ffn_down.bias"); + const std::regex pattern_ffn_down_exps_bias("blk\\.\\d*\\.ffn_down_exps.bias"); + + const std::regex pattern_output_weight("output\\.weight"); + const std::regex pattern_output_bias ("output\\.bias"); + + struct tensor_config { + ggml_backend_meta_split_axis axis; + + const ggml_tensor * tensor_axis_0; + + uint32_t il; + size_t rotation; // when assigning tensor slices, rotate how the rounding is done for more even allocation + }; + + auto get_tensor_config_impl = [&]( + const ggml_backend_meta_split_axis axis, const std::string & suffix = "", const std::string & suffix_fallback = "") -> tensor_config { + // the layers in a tensor can be inhomogeneous, if the pattern is cleanly divided by the number of GPUs there can be aliasing effects, + // count only the same type of previous layers to avoid this + auto get_il_eff = [&](const size_t il){ + size_t ret = 0; + const bool il_is_recurrent = hparams.is_recurrent(il); + const bool il_is_swa = hparams.is_swa(il); + for (size_t il_prev = 0; il_prev < il; il_prev++) { + ret += hparams.is_recurrent(il_prev) == il_is_recurrent && hparams.is_swa(il_prev) == il_is_swa; + } + return ret; + }; + + uint32_t il; + std::string prefix; + size_t rotation; + if (tensor_name.substr(0, 4) == "blk.") { + const size_t length_prefix = tensor_name.find('.', 4); + GGML_ASSERT(length_prefix != std::string::npos); + prefix = tensor_name.substr(0, length_prefix + 1); + il = std::stoull(tensor_name.substr(4, length_prefix)); + rotation = get_il_eff(il) % ud->n_devices; + } else if (tensor_name.substr(0, 6) == "cache_") { + const size_t layer_index_start = tensor_name.find("_l", 6); + GGML_ASSERT(layer_index_start != std::string::npos); + il = std::stoull(tensor_name.substr(layer_index_start + 2)); + prefix = "blk." + std::to_string(il) + "."; + rotation = get_il_eff(il) % ud->n_devices; + } else { + il = 0; + rotation = hparams.n_layer % ud->n_devices; + } + const ggml_tensor * tensor_axis_0 = suffix.empty() ? tensor : ud->model->get_tensor((prefix + suffix).c_str()); + if (tensor_axis_0 == nullptr) { + GGML_ASSERT(!suffix_fallback.empty()); + tensor_axis_0 = ud->model->get_tensor((prefix + suffix_fallback).c_str()); + } + GGML_ASSERT(tensor_axis_0 != nullptr); + return {axis, tensor_axis_0, il, rotation}; + }; + + auto get_tensor_config = [&]() -> tensor_config { + // standard attention + if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_kv_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_q_bias) || std::regex_match(tensor_name, pattern_kv_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_qkv_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + } + if ( std::regex_match(tensor_name, pattern_qkv_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + if (std::regex_match(tensor_name, pattern_qk_norm)) { + return get_tensor_config_impl(tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_kv_cache) || std::regex_match(tensor_name, pattern_attn_sinks)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_attn_out_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + if (std::regex_match(tensor_name, pattern_attn_out_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); + } + + if (std::regex_match(tensor_name, pattern_attn_gate_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + } + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta) || + std::regex_match(tensor_name, pattern_ssm_beta_alpha)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_r_cache) || std::regex_match(tensor_name, pattern_s_cache)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_conv1d)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + + // FFN + if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_up_gate_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_down_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_down_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); + } + if (std::regex_match(tensor_name, pattern_ffn_down_exps_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_PARTIAL); + } + + // output + if (std::regex_match(tensor_name, pattern_output_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + } + if (std::regex_match(tensor_name, pattern_output_bias)) { + const ggml_tensor * output_weight = ud->model->get_tensor("output.weight"); + GGML_ASSERT(output_weight != nullptr); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + + // everything else + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); + }; + + auto get_split_segments = [&](int axis, uint32_t il) -> std::vector { + if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + + // both Qwen 3 Next and Qwen 3.5 support n_v_heads > n_k_heads but the broadcasting pattern is different: + // - Qwen 3 Next: [k0_v0, k0_v1, k1_v2, k1_v3] (this is the default split pattern) + // - Qwen 3.5: [k0_v0, k1_v1, k0_v2, k1_v3] (needs segmenting of V on the scale of K to get the correct pattern) + if (ud->model->arch == LLM_ARCH_QWEN3NEXT) { + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { + GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); + return {key_dim, key_dim, value_dim}; + } + } else { + const int64_t head_ratio = n_v_heads / n_k_heads; + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { + GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); + return std::vector(2 + head_ratio, key_dim); + } + if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return std::vector(head_ratio, key_dim); + } + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) || + std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) { + return std::vector(head_ratio, n_k_heads); + } + if (std::regex_match(tensor_name, pattern_r_cache)) { + return std::vector(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1)); + } + if (std::regex_match(tensor_name, pattern_s_cache)) { + return std::vector(head_ratio, n_k_heads * head_v_dim * head_v_dim); + } + } + + // the FFN is the same for Qwen 3 Next and Qwen 3.5: + if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { + const int64_t n_ff_exp = hparams.n_ff_exp; + GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); + return {n_ff_exp, n_ff_exp}; + } + return {tensor->ne[axis]}; + } + + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(il); + GGML_ASSERT(hparams.n_embd_k_gqa() == n_embd_gqa); + GGML_ASSERT(tensor->ne[axis] == n_embd + 2*n_embd_gqa); + return {n_embd, n_embd_gqa, n_embd_gqa}; + } + if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { + const int64_t n_ff_exp = hparams.n_ff_exp; + GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); + return {n_ff_exp, n_ff_exp}; + } + return {tensor->ne[axis]}; + }; + + auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector & segments) -> std::vector { + if (hparams.is_recurrent(il)) { + // linear attention + const int64_t head_dim = hparams.ssm_d_state; + const int64_t granularity_qkv = std::lcm(blck_size, head_dim); + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) || + std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return std::vector(segments.size(), granularity_qkv); + } + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) || + std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) { + return std::vector(segments.size(), granularity_qkv / head_dim); + } + if (std::regex_match(tensor_name, pattern_ssm_beta_alpha)) { + return std::vector(segments.size(), 2 * (granularity_qkv / head_dim)); + } + if (std::regex_match(tensor_name, pattern_r_cache)) { + return std::vector(segments.size(), granularity_qkv * (hparams.ssm_d_conv - 1)); + } + if (std::regex_match(tensor_name, pattern_s_cache)) { + return std::vector(segments.size(), granularity_qkv * head_dim); + } + } else { + // regular attention + const uint32_t n_gqa = hparams.n_gqa(il); + const uint32_t n_embd_q = n_gqa * hparams.n_embd_head_k(il); + if (std::regex_match(tensor_name, pattern_attn_sinks)) { + GGML_ASSERT(segments.size() == 1); + return {std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa}; + } + + const int64_t granularity_q = std::lcm(n_embd_q, blck_size); + if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_q_bias)) { + GGML_ASSERT(segments.size() == 1); + // some models have Q gate tensors, for those cases the granularity needs to be doubled: + if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { + return {std::lcm(2*n_embd_q, blck_size)}; + } + return {granularity_q}; + } + if (std::regex_match(tensor_name, pattern_attn_out_weight)) { + GGML_ASSERT(segments.size() == 1); + return {granularity_q}; + } + + const int64_t granularity_kv = granularity_q / n_gqa; + if (std::regex_match(tensor_name, pattern_kv_weight) || + std::regex_match(tensor_name, pattern_kv_bias) || + std::regex_match(tensor_name, pattern_kv_cache)) { + GGML_ASSERT(segments.size() == 1); + return {granularity_kv}; + } + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { + GGML_ASSERT(segments.size() == 3); + return {granularity_q, granularity_kv, granularity_kv}; + } + } + + // FFN + if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) || + std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) { + GGML_ASSERT(segments.size() <= 2); + return std::vector(segments.size(), blck_size); + } + + // everything else + GGML_ASSERT(segments.size() == 1); + return {1}; + }; + + ggml_backend_meta_split_state split_state; + memset(&split_state, 0, sizeof(split_state)); + tensor_config tc = get_tensor_config(); + split_state.axis = tc.axis; + if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { + const int64_t ne_full = tensor->ne[split_state.axis]; + const int64_t blck_size = ggml_blck_size(tc.tensor_axis_0->type); + const float * tensor_split = ud->model->tensor_split(); + std::vector tensor_split_scan; + tensor_split_scan.reserve(ud->n_devices); + for (size_t j = 0; j < ud->n_devices; j++) { + tensor_split_scan.push_back(tensor_split == nullptr ? 0.0f : tensor_split[(j + tc.rotation) % ud->n_devices]); + if (j > 0) { + tensor_split_scan[j] += tensor_split_scan[j - 1]; + } + } + const std::vector segments = get_split_segments(split_state.axis, tc.il); + const std::vector granularity = get_split_granularity(blck_size, tc.il, segments); + for (size_t is = 0; is < segments.size(); is++) { + const int64_t ne_s = segments[is]; + const int64_t g_s = granularity[is]; + GGML_ASSERT(ne_full % g_s == 0); + int64_t low = 0; + size_t j = 0; + for (; j < ud->n_devices - 1; j++) { + int64_t high = tensor_split_scan.back() == 0.0f ? + ne_s * (j+1)/ud->n_devices : ne_s * tensor_split_scan[j]/tensor_split_scan.back(); + if (high % g_s != 0) { + high -= high % g_s; + } + split_state.ne[is*ud->n_devices + (j + tc.rotation) % ud->n_devices] = high - low; + low = high; + } + split_state.ne[is*ud->n_devices + (j + tc.rotation) % ud->n_devices] = ne_s - low; + } + split_state.n_segments = segments.size(); + } else { + memset(split_state.ne, 0, sizeof(split_state.ne)); + split_state.n_segments = 1; + } + return split_state; + GGML_UNUSED(userdata); +} const char * llm_type_name(llm_type type) { switch (type) { @@ -93,6 +445,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_26B: return "26B"; case LLM_TYPE_27B: return "27B"; case LLM_TYPE_30B: return "30B"; + case LLM_TYPE_31B: return "31B"; case LLM_TYPE_32B: return "32B"; case LLM_TYPE_34B: return "34B"; case LLM_TYPE_35B: return "35B"; @@ -127,6 +480,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_24B_A2B: return "24B.A2B"; + case LLM_TYPE_26B_A4B: return "26B.A4B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_35B_A3B: return "35B.A3B"; @@ -181,7 +535,7 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st } // CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; // add ACCEL buffer types @@ -203,10 +557,10 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // a better approach would be to handle this on a weight-by-weight basis using the offload_op // function of the device to determine if it would benefit from being stored in a host buffer if (!no_host) { - for (auto * dev : devices) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + for (const auto & dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev.dev); if (buft) { - buft_list.emplace_back(dev, buft); + buft_list.emplace_back(dev.dev, buft); break; } } @@ -273,14 +627,16 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s // add the device extra buffer type (if any) ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); - auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) - ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts"); - - if (ggml_backend_dev_get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev); - while (extra_bufts && *extra_bufts) { - buft_list.emplace_back(dev, *extra_bufts); - ++extra_bufts; + if (reg) { + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts"); + + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(dev, *extra_bufts); + ++extra_bufts; + } } } @@ -342,6 +698,9 @@ void llama_model::load_arch(llama_model_loader & ml) { if (arch == LLM_ARCH_UNKNOWN) { throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); } + if (!devices.empty() && devices[0].is_meta && !llm_arch_supports_sm_tensor(arch)) { + throw std::runtime_error(std::string("LLAMA_SPLIT_MODE_TENSOR not implemented for architecture '") + llm_arch_name(arch) + "'"); + } } void llama_model::load_hparams(llama_model_loader & ml) { @@ -370,12 +729,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false); ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false); + if (arch == LLM_ARCH_HUNYUAN_VL || arch == LLM_ARCH_HUNYUAN_DENSE) { + if (hparams.n_expert <= 1) { + hparams.n_expert = 0; + hparams.n_expert_used = 0; + } + } + if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl); @@ -454,6 +822,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_ALPHA, hparams.rope_scaling_alpha, false); // non-transformer models do not have attention heads if (hparams.n_head() > 0) { @@ -748,8 +1117,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 3: @@ -781,8 +1148,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 12: @@ -797,8 +1162,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; switch (hparams.n_layer) { @@ -810,8 +1173,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_JINA_BERT_V3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 24: @@ -823,8 +1184,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_NOMIC_BERT_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); if (hparams.n_layer == 12 && hparams.n_embd == 768) { @@ -838,8 +1197,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_NEO_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); if (hparams.n_layer == 28) { type = LLM_TYPE_250M; @@ -848,8 +1205,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_EUROBERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); if (hparams.n_layer == 12) { type = LLM_TYPE_SMALL; // 0.2B @@ -913,7 +1268,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { // fall through case LLM_ARCH_QWEN2: { - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; @@ -940,8 +1294,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } // Set non-causal attention for diffusion models hparams.causal_attn = false; - } - break; + } break; case LLM_ARCH_LLADA: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -955,8 +1308,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } // Set non-causal attention for diffusion models hparams.causal_attn = false; - } - break; + } break; case LLM_ARCH_LLADA_MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -995,7 +1347,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN3: { - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; @@ -1275,6 +1626,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GEMMA4: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t)n_kv_shared_layers; + hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling) + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, hparams.n_embd_per_layer); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_26B_A4B; break; + case 35: type = LLM_TYPE_E2B; break; + case 42: type = LLM_TYPE_E4B; break; + case 60: type = LLM_TYPE_31B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GEMMA_EMBEDDING: { hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; @@ -1287,7 +1666,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); //applied only if model converted with --sentence-transformers-dense-modules ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); @@ -1587,6 +1965,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_MISTRAL4: { // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); @@ -1623,7 +2002,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // (optional) temperature tuning - used by mistral-large ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); // FIXME why not use temperature_length? hparams.f_attn_temp_offset = 0.0f; @@ -1635,6 +2014,26 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_DEEPSEEK2OCR: + { + // similar to deepseek2, but without MLA + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } + + switch (hparams.n_layer) { + case 12: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_PLM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1672,6 +2071,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters (GLM-OCR) ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -1705,6 +2105,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -1751,6 +2152,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -1925,6 +2327,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); switch (hparams.n_layer) { case 32: type = LLM_TYPE_30B_A3B; break; @@ -2053,7 +2456,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_embd) { case 768: type = LLM_TYPE_350M; break; - case 1536: type = (hparams.n_embd == 2048 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; + case 1536: type = (hparams.n_ff() == 512 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; case 2048: case 2560: type = LLM_TYPE_3B; break; case 4096: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; @@ -2079,7 +2482,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); } break; case LLM_ARCH_BAILINGMOE: { @@ -2107,6 +2509,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -2197,9 +2600,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2)) + if (hparams.rope_scaling_alpha > 0.0f) { + const int dim = hparams.n_embd_head_k(); + hparams.rope_freq_base_train = hparams.rope_freq_base_train + * powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2)); + } switch (hparams.n_embd) { case 1024: type = LLM_TYPE_0_5B; break; @@ -2588,11 +3000,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // build a list of buffer types for the CPU and GPU devices pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); - for (auto * dev : devices) { - buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); + for (const auto & dev : devices) { + buft_list_t buft_list = make_gpu_buft_list(dev.dev, split_mode, tensor_split); // add CPU buffer types as a fallback buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); - pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); + pimpl->gpu_buft_list.emplace(dev.dev, std::move(buft_list)); } ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); @@ -2606,7 +3018,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (all_zero) { // default split, by free memory for (size_t i = 0; i < n_devices(); ++i) { - ggml_backend_dev_t dev = devices[i]; + ggml_backend_dev_t dev = devices[i].dev; size_t total; size_t free; ggml_backend_dev_memory(dev, &free, &total); @@ -2642,7 +3054,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { return {cpu_dev, &pimpl->cpu_buft_list}; } const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); - auto * dev = devices.at(layer_gpu); + auto * dev = devices.at(layer_gpu).dev; LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa); return {dev, &pimpl->gpu_buft_list.at(dev)}; }; @@ -2708,6 +3120,25 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); } }; + + // helper: try to load merged qkv first, fall back to separate q, k, v + auto create_tensor_qkv = [&](llama_layer & layer, int bid, + int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_, + int flags) { + const int64_t n_embd_qkv = n_embd_q_ + n_embd_k_ + n_embd_v_; + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + if (layer.wqkv) { + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_q_}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", bid), {n_embd_, n_embd_v_}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); + } + }; + switch (arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_REFACT: @@ -2733,16 +3164,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -2805,7 +3231,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); @@ -2841,9 +3267,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -2882,9 +3306,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -2928,7 +3350,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); const int64_t n_ff = hparams.n_ff(i); const int64_t n_head = hparams.n_head(i); const int64_t n_head_kv = hparams.n_head_kv(i); @@ -2941,17 +3362,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { else if (n_head_kv > 0) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); } // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); if (n_ff > 0) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3043,9 +3459,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); @@ -3108,9 +3522,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3175,10 +3587,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -3211,28 +3623,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); } - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - if (!layer.wqkv) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); @@ -3259,7 +3659,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_MODERN_BERT: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -3325,9 +3725,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3342,31 +3740,24 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); // LayerNorm + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // LayerNorm bias cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED); for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; // JinaBertLayer - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); @@ -3394,8 +3785,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_BLOOM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -3414,10 +3805,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -3450,10 +3841,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); @@ -3490,16 +3881,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - // optional bias tensors, present in Stable LM 2 1.6B - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - // optional q and k layernorms, present in StableLM 2 12B layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); @@ -3527,7 +3911,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3557,16 +3941,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -3587,16 +3964,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); @@ -3645,9 +4015,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -3678,9 +4046,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -3721,22 +4087,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); @@ -3763,7 +4117,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, TENSOR_NOT_REQUIRED); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); @@ -3793,19 +4147,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); @@ -3832,9 +4176,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -3971,10 +4313,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4006,11 +4348,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4036,9 +4377,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4062,9 +4401,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4086,9 +4423,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4110,9 +4445,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); @@ -4147,9 +4480,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); @@ -4175,13 +4506,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0); - per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0); + per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0); output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -4190,9 +4522,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); @@ -4219,6 +4549,101 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GEMMA4: + { + const uint32_t n_embd_per_layer = hparams.n_embd_per_layer; + const int64_t n_ff_exp = hparams.n_ff_exp; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 requires n_embd_head_k_swa == n_embd_head_v_swa"); + } + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + if (n_embd_per_layer > 0) { + per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0); + } + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + int rope_freqs_flag = 0; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_embd_k = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v = hparams.n_embd_v_gqa(i); + const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj) + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED); + + if (!hparams.is_swa(i)) { + // full_attention layers use rope_freqs for proportional rope + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_embd_head/2}, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + // handle use_double_wide_mlp + int64_t n_ff_cur = hparams.n_ff(i); + + // for expert layers, we use normal FFN as shared expert (same as python code) + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + // MoE router + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + bool has_expert = layer.ffn_gate_inp != nullptr; + + // norm + if (has_expert) { + layer.ffn_gate_inp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "scale", i), {n_embd}, 0); + + layer.ffn_pre_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM_2, "weight", i), {n_embd}, 0); + layer.ffn_post_norm_1 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_1, "weight", i), {n_embd}, 0); + layer.ffn_post_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_2, "weight", i), {n_embd}, 0); + + // MoE FFN + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + + // per-expert scale will be loaded as down_exps_s at the end of the current switch case + } + + // per-layer embeddings + if (n_embd_per_layer > 0) { + layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_per_layer}, 0); + layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_per_layer, n_embd}, 0); + layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); + } + } + } break; case LLM_ARCH_STARCODER2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4239,16 +4664,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4414,9 +4834,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } else { // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); } @@ -4492,14 +4910,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_head_i = hparams.n_head(i); const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } // feed forward (w/ optional biases) @@ -4542,9 +4955,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4572,9 +4983,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); } - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -4597,9 +5006,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); @@ -4622,9 +5029,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -4645,9 +5050,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_head_kv * n_embd_head}, 0); @@ -4678,14 +5081,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0); + create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); @@ -4709,9 +5107,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); @@ -4778,10 +5174,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4811,9 +5207,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4850,9 +5244,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4883,6 +5275,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_MISTRAL4: { const bool is_mla = hparams.is_mla(); @@ -4960,6 +5353,60 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } + } break; + case LLM_ARCH_DEEPSEEK2OCR: + { + // similar to deepseek2, but without MLA + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // norm + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Shared expert branch layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); @@ -5145,10 +5592,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5187,10 +5634,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // attention biases - all have shape n_embd (output dimension of projections) - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5218,17 +5665,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - } + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); @@ -5261,17 +5698,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, flags); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - } + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); @@ -5329,12 +5756,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); // GLM-style attention with bias terms - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, TENSOR_NOT_REQUIRED | flags); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, TENSOR_NOT_REQUIRED | flags); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, TENSOR_NOT_REQUIRED | flags); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); @@ -5514,16 +5936,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5590,14 +6007,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_head_i = hparams.n_head(i); const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } else { if (n_expert != 0) { const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; @@ -5645,9 +6057,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -5673,9 +6083,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); @@ -5718,9 +6126,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, flags); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, flags); + create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags); layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags); @@ -5773,8 +6179,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // Block 0, LN0 - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -5888,8 +6294,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // Block 0, LN0 - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -6044,9 +6450,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6060,8 +6464,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); - conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); - conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0); + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight", 0), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); + conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias", 0), {1, hparams.posnet.n_embd}, 0); // posnet { @@ -6126,8 +6530,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {hparams.posnet.n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {hparams.posnet.n_embd}, 0); // convnext { @@ -6175,9 +6579,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6278,9 +6680,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_head_k * n_head, n_embd_head_k * n_head, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -6333,9 +6733,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6370,9 +6768,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); // attention projections - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // Q/K normalization @@ -6430,16 +6826,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6519,14 +6910,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { /*ATTENTION LAYERS*/ // attention layers (with optional bias) - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, n_embd_head_k * attn_num_attention_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_k}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_v}, 0); + create_tensor_qkv(layer, i, hidden_size, n_embd_head_k * attn_num_attention_head, attn_num_key_value_head * n_embd_head_k, attn_num_key_value_head * n_embd_head_v, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * n_embd_head_k}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * n_embd_head_v}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); @@ -6560,9 +6946,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -6580,6 +6964,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); } } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6597,9 +6982,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -6631,9 +7014,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6658,9 +7039,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0); @@ -6670,11 +7049,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - // bias - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_head_kv * n_rot}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_head_kv * n_rot}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); @@ -6722,9 +7097,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, hparams.n_embd_k_gqa(i)}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, hparams.n_embd_v_gqa(i)}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, hparams.n_embd_k_gqa(i), hparams.n_embd_v_gqa(i), 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); } else { @@ -6756,9 +7129,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); @@ -6795,9 +7166,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -6841,16 +7210,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); @@ -6874,9 +7238,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -6933,9 +7295,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // q, k, v projections // Python: q_proj, k_proj, v_proj - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_v_kda * n_head}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k_kda * n_head, n_embd_head_k_kda * n_head, n_embd_head_v_kda * n_head, 0); // KDA specific projections // f_a_proj, f_b_proj @@ -7081,16 +7441,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); // weight tensors - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -7147,9 +7502,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!hparams.is_recurrent(i)) { // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // Q/K normalization for attention layers @@ -7213,9 +7566,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!hparams.is_recurrent(i)) { // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // Q/K normalization for attention layers @@ -7278,9 +7629,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!hparams.is_recurrent(i)) { // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // Q/K normalization for attention layers @@ -7319,9 +7668,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); uint32_t n_head = hparams.n_head(i); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -7380,9 +7727,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_l}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); // head-wise attention gate (Step35 self_attn.g_proj) @@ -7426,9 +7771,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -7501,6 +7844,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // recurrent / linear-attention weight scales (per-tensor, shape {1}) + if (!layer.ssm_in_s && layer.ssm_in) { + layer.ssm_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } if (!layer.ssm_out_s && layer.ssm_out) { layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); } @@ -7510,11 +7856,77 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_s && layer.ssm_beta) { layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); } + + // input scales + if (!layer.wq_in_s && layer.wq) { + layer.wq_in_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wk_in_s && layer.wk) { + layer.wk_in_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wv_in_s && layer.wv) { + layer.wv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wo_in_s && layer.wo) { + layer.wo_in_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_in_s && layer.wqkv) { + layer.wqkv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_in_s && layer.wqkv_gate) { + layer.wqkv_gate_in_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_in_s && layer.ffn_gate) { + layer.ffn_gate_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_in_s && layer.ffn_down) { + layer.ffn_down_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_in_s && layer.ffn_up) { + layer.ffn_up_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_exps_in_s && layer.ffn_gate_exps) { + layer.ffn_gate_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_exps_in_s && layer.ffn_down_exps) { + layer.ffn_down_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_exps_in_s && layer.ffn_up_exps) { + layer.ffn_up_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_shexp_in_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_in_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_in_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_in_in_s && layer.ssm_in) { + layer.ssm_in_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_out_in_s && layer.ssm_out) { + layer.ssm_out_in_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_in_s && layer.ssm_alpha) { + layer.ssm_alpha_in_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_in_s && layer.ssm_beta) { + layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } } } ml.done_getting_tensors(); + // populate tensors_by_name + for (auto & [_, ctx_ptr] : ml.ctx_map) { + for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { + tensors_by_name.emplace_back(ggml_get_name(cur), cur); + } + } + ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr); pimpl->mappings.reserve(ml.mappings.size()); @@ -7597,14 +8009,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { buf_map.emplace(idx, buf); } } - pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); - for (auto & buf : buf_map) { + for (auto & buf : bufs) { // indicate that this buffer contains weights // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight - ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + ggml_backend_buffer_set_usage(buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); } + pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); + ctx_buf_maps.emplace_back(ctx, buf_map); } @@ -7632,13 +8045,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } - // populate tensors_by_name - for (auto & [ctx, _] : pimpl->ctxs_bufs) { - for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) { - tensors_by_name.emplace_back(ggml_get_name(cur), cur); - } - } - if (ml.no_alloc) { return true; } @@ -7683,6 +8089,10 @@ size_t llama_model::n_devices() const { return devices.size(); } +const float * llama_model::tensor_split() const { + return params.tensor_split; +} + uint32_t llama_model::n_gpu_layers() const { return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; } @@ -7801,114 +8211,114 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); size_t i = 0; - for (auto label : classifier_labels) { + for (const auto & label : classifier_labels) { LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); } } - } - if (arch == LLM_ARCH_MAMBA || - arch == LLM_ARCH_MAMBA2 || - arch == LLM_ARCH_JAMBA || - arch == LLM_ARCH_FALCON_H1 || - arch == LLM_ARCH_PLAMO2 || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_QWEN3NEXT || - arch == LLM_ARCH_QWEN35 || - arch == LLM_ARCH_QWEN35MOE || - arch == LLM_ARCH_NEMOTRON_H || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); - } + if (arch == LLM_ARCH_MAMBA || + arch == LLM_ARCH_MAMBA2 || + arch == LLM_ARCH_JAMBA || + arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_PLAMO2 || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + } - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); - if (pimpl->n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); - } else if (pimpl->n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); - } else if (pimpl->n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); - } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); - } + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + if (pimpl->n_elements >= 1e12) { + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + } else if (pimpl->n_elements >= 1e9) { + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + } else if (pimpl->n_elements >= 1e6) { + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + } else { + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + } - // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); - if (arch == LLM_ARCH_DEEPSEEK) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - } + if (arch == LLM_ARCH_DEEPSEEK) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + } - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); - LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - if (arch == LLM_ARCH_QWEN2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } + if (arch == LLM_ARCH_QWEN2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - } + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } - if (arch == LLM_ARCH_MINICPM || - arch == LLM_ARCH_GRANITE || - arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); - LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); - LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - } + if (arch == LLM_ARCH_BAILINGMOE) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + } - if (arch == LLM_ARCH_BAILINGMOE2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); - } + if (arch == LLM_ARCH_BAILINGMOE2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + } - if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - if (arch == LLM_ARCH_GROVEMOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); - LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); - LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + if (arch == LLM_ARCH_GROVEMOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + } } vocab.print_info(); @@ -8105,7 +8515,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } else { llama_memory_i::layer_reuse_cb reuse = nullptr; - if (arch == LLM_ARCH_GEMMA3N) { + if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](int32_t il) { if (il >= (int32_t) hparams.n_layer_kv_from_start) { return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); @@ -8168,9 +8578,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_LLAMA4: { if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique>(*this, params); + llm = std::make_unique>(*this, params); } else { - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); } } break; case LLM_ARCH_LLAMA_EMBED: @@ -8248,23 +8658,19 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_DREAM: { llm = std::make_unique(*this, params); - } - break; + } break; case LLM_ARCH_LLADA: { llm = std::make_unique(*this, params); - } - break; + } break; case LLM_ARCH_LLADA_MOE: { llm = std::make_unique(*this, params); - } - break; + } break; case LLM_ARCH_RND1: { llm = std::make_unique(*this, params); - } - break; + } break; case LLM_ARCH_QWEN2VL: { llm = std::make_unique(*this, params); @@ -8358,6 +8764,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_GEMMA4: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_GEMMA_EMBEDDING: { llm = std::make_unique(*this, params); @@ -8424,7 +8834,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK2OCR: case LLM_ARCH_GLM_DSA: + case LLM_ARCH_MISTRAL4: { llm = std::make_unique(*this, params); } break; @@ -8448,11 +8860,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { switch (params.gtype) { case LLM_GRAPH_TYPE_ENCODER: - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); break; case LLM_GRAPH_TYPE_DEFAULT: case LLM_GRAPH_TYPE_DECODER: - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); break; default: GGML_ABORT("invalid graph type"); @@ -8460,9 +8872,8 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_T5ENCODER: { - llm = std::make_unique(*this, params); - } - break; + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_JAIS: { llm = std::make_unique(*this, params); @@ -8574,6 +8985,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { llm = std::make_unique(*this, params); @@ -8823,6 +9235,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK2OCR: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: @@ -8836,6 +9249,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_MISTRAL4: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: case LLM_ARCH_GLM_DSA: @@ -8874,6 +9288,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: + case LLM_ARCH_GEMMA4: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: @@ -8920,6 +9335,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GLM4_MOE: return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + case LLM_ARCH_HUNYUAN_VL: + return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + // all model arches should be listed explicitly here case LLM_ARCH_UNKNOWN: GGML_ABORT("unknown architecture"); @@ -9054,3 +9472,18 @@ bool llama_model_is_diffusion(const llama_model * model) { const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } + +int32_t llama_model_n_expert(const struct llama_model * model) { + return model->hparams.n_expert; +} + +int32_t llama_model_n_devices(const struct llama_model * model) { + return (int32_t)model->devices.size(); +} + +ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i) { + if (i < 0 || i >= (int)model->devices.size()) { + return nullptr; + } + return model->devices[i].dev; +} diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 25bf892e7e2..5f101bd6374 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -84,6 +84,7 @@ enum llm_type { LLM_TYPE_26B, LLM_TYPE_27B, LLM_TYPE_30B, + LLM_TYPE_31B, LLM_TYPE_32B, LLM_TYPE_34B, LLM_TYPE_35B, @@ -118,6 +119,7 @@ enum llm_type { LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_24B_A2B, // lfm2moe + LLM_TYPE_26B_A4B, // Gemma4 LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, LLM_TYPE_35B_A3B, // Qwen3.5 @@ -244,6 +246,8 @@ struct llama_layer { struct ggml_tensor * wkv_b = nullptr; struct ggml_tensor * wk_b = nullptr; struct ggml_tensor * wv_b = nullptr; + struct ggml_tensor * wqkv_b = nullptr; + struct ggml_tensor * wo_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr; @@ -254,13 +258,6 @@ struct llama_layer { struct ggml_tensor * wo_enc = nullptr; struct ggml_tensor * wqkv_gate = nullptr; - // attention bias - struct ggml_tensor * bq = nullptr; - struct ggml_tensor * bk = nullptr; - struct ggml_tensor * bv = nullptr; - struct ggml_tensor * bo = nullptr; - struct ggml_tensor * bqkv = nullptr; - // relative position bias struct ggml_tensor * attn_rel_b = nullptr; struct ggml_tensor * attn_rel_b_enc = nullptr; @@ -270,6 +267,9 @@ struct llama_layer { struct ggml_tensor * ffn_norm = nullptr; struct ggml_tensor * ffn_norm_b = nullptr; struct ggml_tensor * ffn_post_norm = nullptr; + struct ggml_tensor * ffn_post_norm_1 = nullptr; // gemma4 + struct ggml_tensor * ffn_post_norm_2 = nullptr; // gemma4 + struct ggml_tensor * ffn_pre_norm_2 = nullptr; // gemma4 struct ggml_tensor * layer_out_norm = nullptr; struct ggml_tensor * layer_out_norm_b = nullptr; struct ggml_tensor * ffn_norm_exps = nullptr; @@ -285,6 +285,7 @@ struct llama_layer { // ff MoE struct ggml_tensor * ffn_gate_inp = nullptr; + struct ggml_tensor * ffn_gate_inp_s = nullptr; // gemma4 struct ggml_tensor * ffn_gate_exps = nullptr; struct ggml_tensor * ffn_down_exps = nullptr; struct ggml_tensor * ffn_up_exps = nullptr; @@ -409,10 +410,32 @@ struct llama_layer { struct ggml_tensor * ffn_gate_shexp_s = nullptr; struct ggml_tensor * ffn_up_shexp_s = nullptr; struct ggml_tensor * ffn_down_shexp_s = nullptr; - struct ggml_tensor * ssm_out_s = nullptr; + struct ggml_tensor * ssm_in_s = nullptr; + struct ggml_tensor * ssm_out_s = nullptr; struct ggml_tensor * ssm_alpha_s = nullptr; struct ggml_tensor * ssm_beta_s = nullptr; + // input scales + struct ggml_tensor * wq_in_s = nullptr; + struct ggml_tensor * wk_in_s = nullptr; + struct ggml_tensor * wv_in_s = nullptr; + struct ggml_tensor * wo_in_s = nullptr; + struct ggml_tensor * wqkv_in_s = nullptr; + struct ggml_tensor * wqkv_gate_in_s = nullptr; + struct ggml_tensor * ffn_gate_in_s = nullptr; + struct ggml_tensor * ffn_up_in_s = nullptr; + struct ggml_tensor * ffn_down_in_s = nullptr; + struct ggml_tensor * ffn_gate_exps_in_s = nullptr; + struct ggml_tensor * ffn_down_exps_in_s = nullptr; + struct ggml_tensor * ffn_up_exps_in_s = nullptr; + struct ggml_tensor * ffn_gate_shexp_in_s= nullptr; + struct ggml_tensor * ffn_up_shexp_in_s = nullptr; + struct ggml_tensor * ffn_down_shexp_in_s= nullptr; + struct ggml_tensor * ssm_in_in_s = nullptr; + struct ggml_tensor * ssm_out_in_s = nullptr; + struct ggml_tensor * ssm_alpha_in_s = nullptr; + struct ggml_tensor * ssm_beta_in_s = nullptr; + // altup & laurel struct ggml_tensor * per_layer_inp_gate = nullptr; struct ggml_tensor * per_layer_proj = nullptr; @@ -461,6 +484,9 @@ struct llama_layer { struct ggml_tensor * indexer_attn_k = nullptr; struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias + // gemma4 layer output scale + struct ggml_tensor * out_scale = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; @@ -470,6 +496,19 @@ struct llama_layer { struct llama_layer_nextn nextn; }; +struct llama_device { + bool is_meta; + + ggml_backend_dev_t dev; +}; + +struct llama_meta_device_get_split_state_userdata { + size_t n_devices; + const struct llama_model * model; +}; + +struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata); + struct llama_model { llm_type type = LLM_TYPE_UNKNOWN; llm_arch arch = LLM_ARCH_UNKNOWN; @@ -505,9 +544,9 @@ struct llama_model { struct ggml_tensor * conv1d_b = nullptr; // gemma3n altup - struct ggml_tensor * tok_embd_per_layer = nullptr; struct ggml_tensor * altup_proj = nullptr; struct ggml_tensor * altup_unembd_proj = nullptr; + struct ggml_tensor * per_layer_tok_embd = nullptr; struct ggml_tensor * per_layer_model_proj = nullptr; struct ggml_tensor * per_layer_proj_norm = nullptr; @@ -524,7 +563,7 @@ struct llama_model { std::unordered_map gguf_kv; // list of devices used in this model - std::vector devices; + std::vector devices; // for quantize-stats only std::vector> tensors_by_name; @@ -532,6 +571,9 @@ struct llama_model { // for keeping track of associated LoRA adapters std::unordered_set loras; + // statically allocated context for assigning + struct llama_meta_device_get_split_state_userdata get_split_state_ud; + int64_t t_load_us = 0; int64_t t_start_us = 0; @@ -552,6 +594,7 @@ struct llama_model { size_t size() const; // file size size_t n_tensors() const; size_t n_devices() const; + const float * tensor_split() const; uint32_t n_gpu_layers() const; llama_split_mode split_mode() const; diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 8e8ce231249..2f0f70b73b6 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -1,11 +1,11 @@ -#include "llama.h" #include "llama-impl.h" #include "llama-model.h" #include "llama-model-loader.h" +#include "llama-ext.h" +#include #include #include -#include #include #include #include @@ -84,7 +84,6 @@ static std::string remap_imatrix(const std::string & orig_name, const std::maptensor_types) { - const auto & tensor_types = *static_cast *>(params->tensor_types); - for (const auto & [tname, qtype] : tensor_types) { - tensor_type_patterns.emplace_back(std::regex(tname), qtype); + if (params->tt_overrides) { + for (const auto * p = params->tt_overrides; p->pattern != nullptr; p++) { + tensor_type_patterns.emplace_back(std::regex(p->pattern), p->type); } } } @@ -199,6 +197,7 @@ struct quantize_state_impl { // per-tensor metadata, computed in the preliminary loop and used in the main loop struct tensor_metadata { + std::string name; ggml_type target_type; tensor_category category; std::string remapped_imatrix_name; @@ -344,7 +343,13 @@ static bool tensor_allows_quantization(const llama_model_quantize_params * param quantize &= name.find("attn_rel_b.weight") == std::string::npos; // do not quantize specific multimodal tensors - quantize &= name.find(".position_embd.") == std::string::npos; + quantize &= name.find(".position_embd") == std::string::npos; + quantize &= name.find("sam.pos_embd") == std::string::npos; + quantize &= name.find("sam.neck.") == std::string::npos; + quantize &= name.find("sam.net_") == std::string::npos; + quantize &= name.find(".rel_pos") == std::string::npos; + quantize &= name.find(".patch_embd") == std::string::npos; + quantize &= name.find(".patch_merger") == std::string::npos; return quantize; } @@ -678,9 +683,9 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_mod LLAMA_LOG_WARN("%s: %-36s - applying manual override: %s -> %s\n", __func__, tensor_name.c_str(), ggml_type_name(new_type), ggml_type_name(qtype)); new_type = qtype; - manual = true; - break; } + manual = true; + break; } } } @@ -784,7 +789,7 @@ static bool tensor_requires_imatrix(const char * tensor_name, const ggml_type ds // given a file type, get the default tensor type // -static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { +ggml_type llama_ftype_get_default_type(llama_ftype ftype) { switch (ftype) { case LLAMA_FTYPE_MOSTLY_Q4_0: return GGML_TYPE_Q4_0; case LLAMA_FTYPE_MOSTLY_Q4_1: return GGML_TYPE_Q4_1; @@ -794,6 +799,7 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_F16: return GGML_TYPE_F16; case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16; case LLAMA_FTYPE_ALL_F32: return GGML_TYPE_F32; + case LLAMA_FTYPE_MOSTLY_Q1_0: return GGML_TYPE_Q1_0; case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4; @@ -823,16 +829,32 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ3_S: case LLAMA_FTYPE_MOSTLY_IQ3_M: return GGML_TYPE_IQ3_S; - default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); + default: return GGML_TYPE_COUNT; } } + +static void init_quantize_state_counters(quantize_state_impl & qs, std::vector & metadata) { + for (auto & tm : metadata) { + tensor_category cat = tensor_get_category(tm.name); + tm.category = cat; + + if (category_is_attn_v(cat)) { + ++qs.n_attention_wv; + } + + if (cat == tensor_category::OUTPUT) { + qs.has_tied_embeddings = false; + } + } + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer; +} + // // main quantization driver // static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { - ggml_type default_type; llama_ftype ftype = params->ftype; int nthread = params->nthread; @@ -841,7 +863,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: nthread = std::thread::hardware_concurrency(); } - default_type = llama_ftype_get_default_type(ftype); + ggml_type default_type = llama_ftype_get_default_type(ftype); + if (default_type == GGML_TYPE_COUNT) { + throw std::runtime_error(format("invalid output file type %d\n", ftype)); + } // mmap consistently increases speed on Linux, and also increases speed on Windows with // hot cache. It may cause a slowdown on macOS, possibly related to free memory. @@ -851,15 +876,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: constexpr bool use_mmap = false; #endif - llama_model_kv_override * kv_overrides = nullptr; - if (params->kv_overrides) { - auto * v = (std::vector*)params->kv_overrides; - kv_overrides = v->data(); - } - + const llama_model_kv_override * kv_overrides = params->kv_overrides; std::vector splits = {}; llama_model_loader ml(/*metadata*/ nullptr, /*set_tensor_data*/ nullptr, /*set_tensor_data_ud*/ nullptr, - fname_inp, splits, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); + fname_inp, splits, /*file*/ nullptr, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); @@ -873,9 +893,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (params->only_copy) { ftype = ml.ftype; } + std::unordered_map> i_data; const std::unordered_map> * imatrix_data = nullptr; if (params->imatrix) { - imatrix_data = static_cast>*>(params->imatrix); + for (const llama_model_imatrix_data * p = params->imatrix; p->name != nullptr; p++) { + i_data.emplace(p->name, std::vector(p->data, p->data + p->size)); + } + imatrix_data = & i_data; if (imatrix_data) { LLAMA_LOG_INFO("\n%s: have importance matrix data with %d entries\n", __func__, (int)imatrix_data->size()); @@ -896,7 +920,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::vector prune_list = {}; if (params->prune_layers) { - prune_list = *static_cast *>(params->prune_layers); + for (const int32_t * p = params->prune_layers; * p != -1; p++) { + prune_list.push_back(* p); + } } // copy the KV pairs from the input file @@ -910,20 +936,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str()); if (params->kv_overrides) { - const std::vector & overrides = *(const std::vector *)params->kv_overrides; - for (const auto & o : overrides) { - if (o.key[0] == 0) break; - if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { - gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64); - } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) { + for (const llama_model_kv_override * o = params->kv_overrides; o->key[0] != 0; ++o) { + if (o->tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { + gguf_set_val_f32(ctx_out.get(), o->key, o->val_f64); + } else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_INT) { // Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context - gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)std::abs(o.val_i64)); - } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { - gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool); - } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) { - gguf_set_val_str(ctx_out.get(), o.key, o.val_str); + gguf_set_val_u32(ctx_out.get(), o->key, (uint32_t)std::abs(o->val_i64)); + } else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { + gguf_set_val_bool(ctx_out.get(), o->key, o->val_bool); + } else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_STR) { + gguf_set_val_str(ctx_out.get(), o->key, o->val_str); } else { - LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key); + LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o->key); } } } @@ -961,6 +985,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: }); } + // compute tensor metadata once and cache it + std::vector metadata(tensors.size()); + for (size_t i = 0; i < tensors.size(); ++i) { + metadata[i].name = ggml_get_name(tensors[i]->tensor); + } + + // initialize quantization state counters and metadata categories + init_quantize_state_counters(qs, metadata); + int idx = 0; uint16_t n_split = 1; @@ -973,25 +1006,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::vector ctx_outs(n_split); ctx_outs[0] = std::move(ctx_out); - // compute tensor metadata once and cache it - std::vector metadata(tensors.size()); - - // initialize quantization state before preliminary loop (counters for use_more_bits) - { - for (size_t i = 0; i < tensors.size(); ++i) { - const auto cat = tensor_get_category(tensors[i]->tensor->name); - if (category_is_attn_v(cat)) { - ++qs.n_attention_wv; - } - if (cat == tensor_category::OUTPUT) { - qs.has_tied_embeddings = false; - } - metadata[i].category = cat; // save and re-use the category while we're at it - } - // these also need to be set to n_layer by default - qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer; - } - // flag for --dry-run bool will_require_imatrix = false; @@ -1002,7 +1016,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: for (size_t i = 0; i < tensors.size(); ++i) { const auto * it = tensors[i]; const struct ggml_tensor * tensor = it->tensor; - const std::string name = ggml_get_name(tensor); uint16_t i_split = params->keep_split ? it->idx : 0; if (!ctx_outs[i_split]) { @@ -1031,7 +1044,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: " - offending tensor: %s\n" " - target type: %s\n" "============================================================================\n\n", - name.c_str(), ggml_type_name(metadata[i].target_type)); + metadata[i].name.c_str(), ggml_type_name(metadata[i].target_type)); throw std::runtime_error("this quantization requires an imatrix!"); } } @@ -1104,7 +1117,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: new_ofstream(weight.idx); } - const std::string name = ggml_get_name(tensor); const size_t tensor_size = ggml_nbytes(tensor); if (!params->dry_run) { @@ -1235,9 +1247,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: total_size_new += new_size; // update the gguf meta data as we go - gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); - GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); - gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); + gguf_set_tensor_type(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_type); + GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), metadata[i].name.c_str())) == new_size); + gguf_set_tensor_data(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_data); // write tensor data + padding fout.write((const char *) new_data, new_size); @@ -1271,7 +1283,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: llama_model_quantize_params llama_model_quantize_default_params() { llama_model_quantize_params result = { /*.nthread =*/ 0, - /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, + /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q8_0, /*.output_tensor_type =*/ GGML_TYPE_COUNT, /*.token_embedding_type =*/ GGML_TYPE_COUNT, /*.allow_requantize =*/ false, @@ -1302,3 +1314,89 @@ uint32_t llama_model_quantize( return 0; } + +// +// Helper functions for external tools exposed in llama-ext.h +// + +quantize_state_impl * llama_quant_init( + const llama_model * model, + const llama_model_quantize_params * params) { + return new quantize_state_impl(*model, params); +} + +void llama_quant_free(quantize_state_impl * qs) { + delete qs; +} + +llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc) { + struct llama_model_params mparams = llama_model_default_params(); + auto * model = new llama_model(mparams); + + model->arch = llm_arch_from_string(desc->architecture); + + // infer llm_type: only LLM_TYPE_70B matters for quantization logic + if (model->arch == LLM_ARCH_LLAMA && desc->n_layer == 80 && desc->n_head != desc->n_head_kv) { + model->type = LLM_TYPE_70B; + } + + model->hparams.n_embd = desc->n_embd; + model->hparams.n_embd_head_k_full = desc->n_embd_head_k; + model->hparams.n_embd_head_v_full = desc->n_embd_head_v; + model->hparams.n_layer = desc->n_layer; + model->hparams.n_expert = desc->n_expert; + + for (uint32_t i = 0; i < desc->n_layer; i++) { + model->hparams.n_head_arr[i] = desc->n_head; + model->hparams.n_head_kv_arr[i] = desc->n_head_kv; + model->hparams.n_ff_arr[i] = desc->n_ff; + } + + return model; +} + +bool llama_quant_tensor_allows_quantization( + const quantize_state_impl * qs, + const ggml_tensor * tensor) { + return tensor_allows_quantization(qs->params, qs->model.arch, tensor); +} + +void llama_quant_compute_types( + quantize_state_impl * qs, + llama_ftype ftype, + ggml_tensor ** tensors, + ggml_type * result_types, + size_t n_tensors) { + // reset per-computation state + qs->n_attention_wv = 0; + qs->n_ffn_down = 0; + qs->n_ffn_gate = 0; + qs->n_ffn_up = 0; + qs->i_attention_wv = 0; + qs->i_ffn_down = 0; + qs->i_ffn_gate = 0; + qs->i_ffn_up = 0; + qs->n_fallback = 0; + qs->has_imatrix = false; + qs->has_tied_embeddings = true; + + // build metadata from tensor names + std::vector metadata(n_tensors); + for (size_t i = 0; i < n_tensors; i++) { + metadata[i].name = ggml_get_name(tensors[i]); + } + + // initialize counters and categories + init_quantize_state_counters(*qs, metadata); + + // use a local copy of params with the requested ftype + llama_model_quantize_params local_params = *qs->params; + local_params.ftype = ftype; + + ggml_type default_type = llama_ftype_get_default_type(ftype); + + // compute types + for (size_t i = 0; i < n_tensors; i++) { + result_types[i] = llama_tensor_get_type(*qs, &local_params, tensors[i], default_type, metadata[i]); + } +} diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 68ba292d426..163f222ef61 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -493,6 +493,16 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GEMMA4: + // Gemma4 uses SPM-style BPE: spaces are replaced with ▁ by the + // normalizer, then BPE merges run on the whole text without + // word-level pre-splitting. We only need to split on newlines + // since BPE merge lookup asserts no newlines in tokens. + regex_exprs = { + "[^\\n]+|[\\n]+", + }; + byte_encode = false; // uses raw UTF-8, not GPT-2 byte encoding + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -506,6 +516,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { } std::vector regex_exprs; + bool byte_encode = true; // GPT-2 byte encoding; false for SPM-style BPE (raw UTF-8) }; struct llm_tokenizer_bpe_session { @@ -550,9 +561,10 @@ struct llm_tokenizer_bpe_session { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs); + const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode); symbols_final.clear(); + auto tok_pre = vocab.get_pre_type(); for (const auto & word : word_collection) { work_queue = llm_bigram_bpe::queue(); @@ -565,6 +577,13 @@ struct llm_tokenizer_bpe_session { if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) { symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); offset = word.size(); + } else if (tok_pre == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && word.find_first_not_of('\n') == std::string::npos) { + // fix for gemma 4, ref: https://github.com/ggml-org/llama.cpp/pull/21343 + auto tok = vocab.text_to_token(word); + if (tok != LLAMA_TOKEN_NULL) { + symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); + offset = word.size(); + } } while (offset < word.size()) { @@ -640,8 +659,17 @@ struct llm_tokenizer_bpe_session { if (token == LLAMA_TOKEN_NULL) { for (auto j = str.begin(); j != str.end(); ++j) { - std::string byte_str(1, *j); - auto token_multibyte = vocab.text_to_token(byte_str); + llama_token token_multibyte = LLAMA_TOKEN_NULL; + if (tokenizer.byte_encode) { + std::string byte_str(1, *j); + token_multibyte = vocab.text_to_token(byte_str); + } else { + // For non-byte-encoded BPE (e.g. gemma-4), byte tokens use <0xXX> format + static const char * hex = "0123456789ABCDEF"; + const uint8_t ch = (uint8_t)*j; + const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 }; + token_multibyte = vocab.text_to_token(buf); + } if (token_multibyte != LLAMA_TOKEN_NULL) { output.push_back(token_multibyte); } @@ -1863,6 +1891,42 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_sep_id = LLAMA_TOKEN_NULL; special_pad_id = 3; // <|plamo:pad|> special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "gemma4") { + type = LLAMA_VOCAB_TYPE_BPE; + + // read bpe merges and populate bpe ranks + const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + { + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + bpe_ranks.emplace(std::make_pair(first, second), i); + } + } + + // default special tokens (to be read from GGUF) + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + + tokenizer_pre = "gemma4"; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -1870,6 +1934,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // for now, only BPE models have pre-tokenizers if (type == LLAMA_VOCAB_TYPE_BPE) { add_space_prefix = false; + escape_whitespaces = false; clean_spaces = true; if (tokenizer_pre.empty()) { LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); @@ -1936,6 +2001,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "jais-2") { pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2; + } else if ( + tokenizer_pre == "gemma4") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4; + escape_whitespaces = true; } else if ( tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-code" || @@ -1952,7 +2021,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "qwen2" || tokenizer_pre == "deepseek-r1-qwen" || - tokenizer_pre == "kormo") { + tokenizer_pre == "kormo" || + tokenizer_pre == "f2llmv2") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; } else if ( @@ -2129,19 +2199,28 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { throw std::runtime_error("cannot find tokenizer vocab in model file\n"); } + const uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx); + const float * scores = nullptr; const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); if (score_idx != -1) { + const uint32_t n_scores = gguf_get_arr_n(ctx, score_idx); + if (n_scores < n_tokens) { + throw std::runtime_error("Index out of array bounds for scores (" + std::to_string(n_scores) + " < " + std::to_string(n_tokens) + ")\n"); + } scores = (const float * ) gguf_get_arr_data(ctx, score_idx); } const int * toktypes = nullptr; const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); if (toktype_idx != -1) { + const uint32_t n_toktypes = gguf_get_arr_n(ctx, toktype_idx); + if (n_toktypes < n_tokens) { + throw std::runtime_error("Index out of array bounds for toktypes (" + std::to_string(n_toktypes) + " < " + std::to_string(n_tokens) + ")\n"); + } toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); } - uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx); id_to_token.resize(n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { @@ -2255,6 +2334,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { if (ml.get_key(LLM_KV_TOKENIZER_ADD_SEP, temp, false)) { add_sep = temp; } + + // workaround for Gemma 4 + // ref: https://github.com/ggml-org/llama.cpp/pull/21500 + if (pre_type == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && !add_bos) { + add_bos = true; + + LLAMA_LOG_WARN("%s: override '%s' to 'true' for Gemma4\n", __func__, kv(LLM_KV_TOKENIZER_ADD_BOS).c_str()); + } } // auto-detect special tokens by text @@ -2480,6 +2567,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "[EOS]" // Kimi-K2 || t.first == "<|end_of_text|>" || t.first == "" // smoldocling + || t.first == "" // gemma4 + || t.first == "" // gemma4 + || t.first == "<|tool_response>" // gemma4 + || t.first == "<|end▁of▁sentence|>" // deepseek-ocr ) { special_eog_ids.insert(t.second); if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { @@ -2564,6 +2655,33 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__); } } + + // workaround for gemma4 and paddleocr: do not include as an eog token + { + bool has_tool_response = false; + bool has_s = false; + + llama_token s_id = LLAMA_TOKEN_NULL; + + for (auto tid : special_eog_ids) { + const auto & text = id_to_token[tid].text; + if (text == "<|tool_response>") { + has_tool_response = true; + } else if (text == "") { + has_s = true; + s_id = tid; + } + } + + if (has_tool_response && has_s) { + special_eog_ids.erase(s_id); + + auto & attr = id_to_token[s_id].attr; + attr = LLAMA_TOKEN_ATTR_NORMAL; + + LLAMA_LOG_WARN("%s: special_eog_ids contains '<|tool_response>', removing '' token from EOG list\n", __func__); + } + } } // build special tokens cache @@ -2732,7 +2850,9 @@ uint8_t llama_vocab::impl::token_to_byte(llama_token id) const { return strtol(buf.c_str(), NULL, 16); } case LLAMA_VOCAB_TYPE_BPE: { - GGML_ABORT("fatal error"); + // Gemma4 uses BPE with SPM-style byte fallback tokens (<0xXX>) + auto buf = token_data.text.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); } case LLAMA_VOCAB_TYPE_WPM: { GGML_ABORT("fatal error"); @@ -3021,6 +3141,10 @@ std::vector llama_vocab::impl::tokenize( if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { std::string text = fragment.raw_text.substr(fragment.offset, fragment.length); + if (escape_whitespaces) { + llama_escape_whitespace(text); + } + #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); #endif @@ -3200,9 +3324,19 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t return _try_copy(token_text.data(), token_text.size()); } if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + if (escape_whitespaces) { + // SPM-style BPE: tokens contain ▁ for spaces + std::string result = token_text; + llama_unescape_whitespace(result); + return _try_copy(result.data(), result.size()); + } std::string result = llama_decode_text(token_text); return _try_copy(result.data(), result.size()); } + if (attr & LLAMA_TOKEN_ATTR_BYTE) { + char byte = (char) token_to_byte(token); + return _try_copy((char*) &byte, 1); + } break; } case LLAMA_VOCAB_TYPE_RWKV: { @@ -3630,9 +3764,7 @@ int llama_vocab::max_token_len() const { int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const { GGML_ASSERT(token_left.find(' ') == std::string::npos); - GGML_ASSERT(token_left.find('\n') == std::string::npos); GGML_ASSERT(token_right.find(' ') == std::string::npos); - GGML_ASSERT(token_right.find('\n') == std::string::npos); auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right)); if (it == pimpl->bpe_ranks.end()) { diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index be5b08012df..dd38f45d3a2 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -58,6 +58,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, + LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, }; struct LLM_KV; diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index 872e659edca..e9c3028585d 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -1,6 +1,5 @@ #include "llama.h" -#include "ggml-cpp.h" #include "llama-impl.h" #include "llama-chat.h" @@ -12,6 +11,7 @@ #include "llama-model.h" #include "ggml.h" +#include "ggml-cpp.h" #include "ggml-backend.h" #include "gguf.h" @@ -24,6 +24,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -45,722 +46,6 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_ty GGML_ABORT("fatal error"); } -struct llama_device_memory_data { - int64_t total; - int64_t free; - llama_memory_breakdown_data mb; -}; - -static std::vector llama_get_device_memory_data( - const char * path_model, const llama_model_params * mparams, const llama_context_params * cparams, - std::vector & devs, uint32_t & hp_ngl, uint32_t & hp_n_ctx_train, uint32_t & hp_n_expert, - const ggml_log_level log_level) { - struct user_data_t { - struct { - ggml_log_callback callback; - void * user_data; - } original_logger; - ggml_log_level min_level; // prints below this log level go to debug log - }; - user_data_t ud; - llama_log_get(&ud.original_logger.callback, &ud.original_logger.user_data); - ud.min_level = log_level; - - llama_log_set([](ggml_log_level level, const char * text, void * user_data) { - const user_data_t * ud = (const user_data_t *) user_data; - const ggml_log_level level_eff = level >= ud->min_level ? level : GGML_LOG_LEVEL_DEBUG; - ud->original_logger.callback(level_eff, text, ud->original_logger.user_data); - }, &ud); - - llama_model_params mparams_copy = *mparams; - mparams_copy.no_alloc = true; - mparams_copy.use_mmap = false; - mparams_copy.use_mlock = false; - - llama_model * model = llama_model_load_from_file(path_model, mparams_copy); - if (model == nullptr) { - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - throw std::runtime_error("failed to load model"); - } - - llama_context * ctx = llama_init_from_model(model, *cparams); - if (ctx == nullptr) { - llama_model_free(model); - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - throw std::runtime_error("failed to create llama_context from model"); - } - - std::vector ret(model->devices.size()); - - std::map memory_breakdown = ctx->memory_breakdown(); - - for (const auto & [buft, mb] : memory_breakdown) { - if (ggml_backend_buft_is_host(buft)) { - continue; - } - - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (!dev) { - continue; - } - for (size_t i = 0; i < ret.size(); i++) { - if (model->devices[i] == dev) { - ret[i].mb.model += mb.model; - ret[i].mb.context += mb.context; - ret[i].mb.compute += mb.compute; - break; - } - } - } - for (size_t i = 0; i < ret.size(); i++) { - size_t free; - size_t total; - ggml_backend_dev_memory(model->devices[i], &free, &total); - - // devices can return 0 bytes for free and total memory if they do not - // have any to report. in this case, we will use the host memory as a fallback - // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 - if (free == 0 && total == 0) { - ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - ggml_backend_dev_memory(cpu_dev, &free, &total); - } - ret[i].free = free; - ret[i].total = total; - } - - devs = model->devices; - hp_ngl = model->hparams.n_layer; - hp_n_ctx_train = model->hparams.n_ctx_train; - hp_n_expert = model->hparams.n_expert; - - llama_memory_breakdown_print(ctx); // goes to debug log - - llama_free(ctx); - llama_model_free(model); - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - return ret; -} - -// enum to identify part of a layer for distributing its tensors: -enum layer_fraction_t { - LAYER_FRACTION_NONE = 0, // nothing - LAYER_FRACTION_ATTN = 1, // attention - LAYER_FRACTION_UP = 2, // attention + up - LAYER_FRACTION_GATE = 3, // attention + up + gate - LAYER_FRACTION_MOE = 4, // everything but sparse MoE weights -}; -// this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue - -class llama_params_fit_exception : public std::runtime_error { - using std::runtime_error::runtime_error; -}; - -static void llama_params_fit_impl( - const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, - float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { - constexpr int64_t MiB = 1024*1024; - typedef std::vector dmds_t; - const llama_model_params default_mparams = llama_model_default_params(); - - std::vector devs; - uint32_t hp_ngl = 0; // hparams.n_gpu_layers - uint32_t hp_nct = 0; // hparams.n_ctx_train - uint32_t hp_nex = 0; // hparams.n_expert - - // step 1: get data for default parameters and check whether any changes are necessary in the first place - - LLAMA_LOG_DEBUG("%s: getting device memory data for initial parameters:\n", __func__); - const dmds_t dmds_full = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - const size_t nd = devs.size(); // number of devices - if (nd == 0) { - LLAMA_LOG_INFO("%s: no devices with dedicated memory found\n", __func__); - return; - } - - std::vector margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits - margins.reserve(nd); - for (size_t id = 0; id < nd; id++) { - margins.push_back(margins_s[id]); - } - - std::vector dev_names; - { - dev_names.reserve(nd); - size_t max_length = 0; - for (ggml_backend_dev_t dev : devs) { - std::string name = ggml_backend_dev_name(dev); - name += " ("; - name += ggml_backend_dev_description(dev); - name += ")"; - dev_names.push_back(name); - max_length = std::max(max_length, name.length()); - } - for (std::string & dn : dev_names) { - dn.insert(dn.end(), max_length - dn.length(), ' '); - } - } - - int64_t sum_free = 0; - int64_t sum_projected_free = 0; - int64_t sum_projected_used = 0; - int64_t sum_projected_model = 0; - std::vector projected_free_per_device; - projected_free_per_device.reserve(nd); - - if (nd > 1) { - LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__); - } - for (size_t id = 0; id < nd; id++) { - const llama_device_memory_data & dmd = dmds_full[id]; - - const int64_t projected_used = dmd.mb.total(); - const int64_t projected_free = dmd.free - projected_used; - projected_free_per_device.push_back(projected_free); - - sum_free += dmd.free; - sum_projected_used += projected_used; - sum_projected_free += projected_free; - sum_projected_model += dmd.mb.model; - - if (nd > 1) { - LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n", - __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB); - } - } - assert(sum_free >= 0 && sum_projected_used >= 0); - LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n", - __func__, sum_projected_used/MiB, sum_free/MiB); - if (nd == 1) { - if (projected_free_per_device[0] >= margins[0]) { - LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", - __func__, projected_free_per_device[0]/MiB, margins[0]/MiB); - return; - } - } else { - bool changes_needed = false; - for (size_t id = 0; id < nd; id++) { - if (projected_free_per_device[id] < margins[id]) { - changes_needed = true; - break; - } - } - if (!changes_needed) { - LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__); - return; - } - } - - // step 2: try reducing memory use by reducing the context size - - { - int64_t global_surplus = sum_projected_free; - for (size_t id = 0; id < nd; id++) { - global_surplus -= margins[id]; - } - if (global_surplus < 0) { - if (nd == 1) { - LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n", - __func__, margins[0]/MiB, -global_surplus/MiB); - } else { - LLAMA_LOG_INFO( - "%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n", - __func__, -global_surplus/MiB); - } - if (cparams->n_ctx == 0) { - if (hp_nct > n_ctx_min) { - int64_t sum_used_target = sum_free; - for (size_t id = 0; id < nd; id++) { - sum_used_target -= margins[id]; - } - if (nd > 1) { - // for multiple devices we need to be more conservative in terms of how much context we think can fit: - // - for dense models only whole layers can be assigned to devices - // - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer - // - on average we expect a waste of 0.5 layers/tensors per device - // - use slightly more than the expected average for nd devices to be safe - const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl); - sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6); - } - - int64_t sum_projected_used_min_ctx = 0; - cparams->n_ctx = n_ctx_min; - const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - for (const auto & dmd : dmds_min_ctx) { - sum_projected_used_min_ctx += dmd.mb.total(); - } - if (sum_used_target > sum_projected_used_min_ctx) { - // linear interpolation between minimum and maximum context size: - cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx) - / (sum_projected_used - sum_projected_used_min_ctx); - cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend - - const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min); - const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx; - LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", - __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); - if (nd == 1) { - LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__); - return; - } - LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__); - } else { - const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx; - LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", - __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); - } - } else { - if (n_ctx_min == UINT32_MAX) { - LLAMA_LOG_INFO("%s: user has requested full context size of %" PRIu32 " -> no change\n", __func__, hp_nct); - } else { - LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n", - __func__, hp_nct, n_ctx_min); - } - } - } else { - LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx); - } - } - } - - if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) { - throw llama_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort"); - } - if (nd > 1) { - if (!tensor_split) { - throw llama_params_fit_exception("did not provide a buffer to write the tensor_split to, abort"); - } - if (mparams->tensor_split) { - for (size_t id = 0; id < nd; id++) { - if (mparams->tensor_split[id] != 0.0f) { - throw llama_params_fit_exception("model_params::tensor_split already set by user, abort"); - } - } - } - if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) { - throw llama_params_fit_exception("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort"); - } - } - if (!tensor_buft_overrides) { - throw llama_params_fit_exception("did not provide buffer to set tensor_buft_overrides, abort"); - } - if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) { - throw llama_params_fit_exception("model_params::tensor_buft_overrides already set by user, abort"); - } - - // step 3: iteratively fill the back to front with "dense" layers - // - for a dense model simply fill full layers, giving each device a contiguous slice of the model - // - for a MoE model, same as dense model but with all MoE tensors in system memory - - // utility function that returns a static C string matching the tensors for a specific layer index and layer fraction: - auto get_overflow_pattern = [&](const size_t il, const layer_fraction_t lf) -> const char * { - constexpr size_t n_strings = 1000; - if (il >= n_strings) { - throw std::runtime_error("at most " + std::to_string(n_strings) + " model layers are supported"); - } - switch (lf) { - case LAYER_FRACTION_ATTN: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|gate|down).*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_UP: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|down).*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_GATE: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_down.*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_MOE: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|down|gate)_(ch|)exps"; - } - return patterns[il].c_str(); - } - default: - GGML_ABORT("fatal error"); - } - }; - - struct ngl_t { - uint32_t n_layer = 0; // number of total layers - uint32_t n_part = 0; // number of partial layers, <= n_layer - - // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE: - layer_fraction_t overflow_type = LAYER_FRACTION_MOE; - - uint32_t n_full() const { - assert(n_layer >= n_part); - return n_layer - n_part; - } - }; - - const size_t ntbo = llama_max_tensor_buft_overrides(); - - // utility function to set n_gpu_layers and tensor_split - auto set_ngl_tensor_split_tbo = [&]( - const std::vector & ngl_per_device, - const std::vector & overflow_bufts, - llama_model_params & mparams) { - mparams.n_gpu_layers = 0; - for (size_t id = 0; id < nd; id++) { - mparams.n_gpu_layers += ngl_per_device[id].n_layer; - if (nd > 1) { - tensor_split[id] = ngl_per_device[id].n_layer; - } - } - assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl + 1); - uint32_t il0 = hp_ngl + 1 - mparams.n_gpu_layers; // start index for tensor buft overrides - - mparams.tensor_split = tensor_split; - - size_t itbo = 0; - for (size_t id = 0; id < nd; id++) { - il0 += ngl_per_device[id].n_full(); - for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) { - if (itbo + 1 >= ntbo) { - tensor_buft_overrides[itbo].pattern = nullptr; - tensor_buft_overrides[itbo].buft = nullptr; - itbo++; - mparams.tensor_buft_overrides = tensor_buft_overrides; - throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == " - + std::to_string(ntbo) + " is insufficient for model"); - } - tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE); - tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type(); - itbo++; - } - il0 += ngl_per_device[id].n_part; - } - tensor_buft_overrides[itbo].pattern = nullptr; - tensor_buft_overrides[itbo].buft = nullptr; - itbo++; - mparams.tensor_buft_overrides = tensor_buft_overrides; - }; - - // utility function that returns the memory use per device for given numbers of layers per device - auto get_memory_for_layers = [&]( - const char * func_name, - const std::vector & ngl_per_device, - const std::vector & overflow_bufts) -> std::vector { - llama_model_params mparams_copy = *mparams; - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy); - - const dmds_t dmd_nl = llama_get_device_memory_data( - path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - - LLAMA_LOG_DEBUG("%s: memory for test allocation by device:\n", func_name); - for (size_t id = 0; id < nd; id++) { - const ngl_t & n = ngl_per_device[id]; - LLAMA_LOG_DEBUG( - "%s: id=%zu, n_layer=%2" PRIu32 ", n_part=%2" PRIu32 ", overflow_type=%d, mem=%6" PRId64 " MiB\n", - func_name, id, n.n_layer, n.n_part, int(n.overflow_type), dmd_nl[id].mb.total()/MiB); - } - - std::vector ret; - ret.reserve(nd); - for (const llama_device_memory_data & dmd : dmd_nl) { - ret.push_back(dmd.mb.total()); - } - return ret; - }; - - int64_t global_surplus_cpu_moe = 0; - if (hp_nex > 0) { - const static std::string pattern_moe_all = "blk\\.\\d+\\.ffn_(up|down|gate)_(ch|)exps"; // matches all MoE tensors - ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type(); - tensor_buft_overrides[0] = {pattern_moe_all.c_str(), cpu_buft}; - tensor_buft_overrides[1] = {nullptr, nullptr}; - mparams->tensor_buft_overrides = tensor_buft_overrides; - - LLAMA_LOG_DEBUG("%s: getting device memory data with all MoE tensors moved to system memory:\n", __func__); - const dmds_t dmds_cpu_moe = llama_get_device_memory_data( - path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - - for (size_t id = 0; id < nd; id++) { - global_surplus_cpu_moe += dmds_cpu_moe[id].free; - global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id]; - } - - if (global_surplus_cpu_moe > 0) { - LLAMA_LOG_INFO("%s: with only dense weights in device memory there is a total surplus of %" PRId64 " MiB\n", - __func__, global_surplus_cpu_moe/MiB); - } else { - LLAMA_LOG_INFO("%s: with only dense weights in device memory there is still a total deficit of %" PRId64 " MiB\n", - __func__, -global_surplus_cpu_moe/MiB); - } - - // reset - tensor_buft_overrides[0] = {nullptr, nullptr}; - mparams->tensor_buft_overrides = tensor_buft_overrides; - } - - std::vector targets; // maximum acceptable memory use per device - targets.reserve(nd); - for (size_t id = 0; id < nd; id++) { - targets.push_back(dmds_full[id].free - margins[id]); - LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); - } - - std::vector overflow_bufts; // which bufts the first partial layer of a device overflows to: - overflow_bufts.reserve(nd); - for (size_t id = 0; id < nd; id++) { - overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); - } - - std::vector ngl_per_device(nd); - std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts); - - // optimize the number of layers per device using the method of false position: - // - ngl_per_device has 0 layers for each device, lower bound - // - try a "high" configuration where a device is given all unassigned layers - // - interpolate the memory use / layer between low and high linearly to get a guess where it meets our target - // - check memory use of our guess, replace either the low or high bound - // - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits - // - the last device has the output layer, which cannot be a partial layer - if (hp_nex == 0) { - LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__); - } else { - LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__); - } - for (int id = nd - 1; id >= 0; id--) { - uint32_t n_unassigned = hp_ngl + 1; - for (size_t jd = id + 1; jd < nd; ++jd) { - assert(n_unassigned >= ngl_per_device[jd].n_layer); - n_unassigned -= ngl_per_device[jd].n_layer; - } - - std::vector ngl_per_device_high = ngl_per_device; - ngl_per_device_high[id].n_layer = n_unassigned; - if (hp_nex > 0) { - ngl_per_device_high[id].n_part = size_t(id) < nd - 1 ? ngl_per_device_high[id].n_layer : ngl_per_device_high[id].n_layer - 1; - } - if (ngl_per_device_high[id].n_layer > 0) { - std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); - if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); - uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta); - while (delta > 1) { - uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); - step_size = std::max(step_size, uint32_t(1)); - step_size = std::min(step_size, delta - 1); - - std::vector ngl_per_device_test = ngl_per_device; - ngl_per_device_test[id].n_layer += step_size; - if (hp_nex) { - ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ? - step_size - 1 : step_size; // the first layer is the output layer which must always be full - } - const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); - - if (mem_test[id] <= targets[id]) { - ngl_per_device = ngl_per_device_test; - mem = mem_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); - } else { - ngl_per_device_high = ngl_per_device_test; - mem_high = mem_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device_high[id].n_layer); - } - delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - } - } else { - assert(ngl_per_device_high[id].n_layer == n_unassigned); - ngl_per_device = ngl_per_device_high; - mem = mem_high; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); - } - } - - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers, %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB); - } - if (hp_nex == 0 || global_surplus_cpu_moe <= 0) { - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); - return; - } - - // step 4: for a MoE model where all dense tensors fit, - // convert the dense-only layers in the back to full layers in the front until all devices are full - // essentially the same procedure as for the dense-only layers except front-to-back - // also, try fitting at least part of one more layer to reduce waste for "small" GPUs with e.g. 24 GiB VRAM - - size_t id_dense_start = nd; - for (int id = nd - 1; id >= 0; id--) { - if (ngl_per_device[id].n_layer > 0) { - id_dense_start = id; - continue; - } - break; - } - assert(id_dense_start < nd); - - LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__); - for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) { - std::vector ngl_per_device_high = ngl_per_device; - for (size_t jd = id_dense_start; jd < nd; jd++) { - const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1; - ngl_per_device_high[id].n_layer += n_layer_move; - ngl_per_device_high[jd].n_layer -= n_layer_move; - ngl_per_device_high[jd].n_part = 0; - } - size_t id_dense_start_high = nd - 1; - std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); - - if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); - uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); - while (delta > 1) { - uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); - step_size = std::max(step_size, uint32_t(1)); - step_size = std::min(step_size, delta - 1); - - std::vector ngl_per_device_test = ngl_per_device; - size_t id_dense_start_test = id_dense_start; - uint32_t n_converted_test = 0; - for (;id_dense_start_test < nd; id_dense_start_test++) { - const uint32_t n_convert_jd = std::min(step_size - n_converted_test, ngl_per_device_test[id_dense_start_test].n_part); - ngl_per_device_test[id_dense_start_test].n_layer -= n_convert_jd; - ngl_per_device_test[id_dense_start_test].n_part -= n_convert_jd; - ngl_per_device_test[id].n_layer += n_convert_jd; - n_converted_test += n_convert_jd; - - if (ngl_per_device_test[id_dense_start_test].n_part > 0) { - break; - } - } - const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); - - if (mem_test[id] <= targets[id]) { - ngl_per_device = ngl_per_device_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } else { - ngl_per_device_high = ngl_per_device_test; - mem_high = mem_test; - id_dense_start_high = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n", - __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high); - } - assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); - delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); - } - } else { - ngl_per_device = ngl_per_device_high; - mem = mem_high; - id_dense_start = id_dense_start_high; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - - // try to fit at least part of one more layer - if (ngl_per_device[id_dense_start].n_layer > (id < nd - 1 ? 0 : 1)) { - std::vector ngl_per_device_test = ngl_per_device; - size_t id_dense_start_test = id_dense_start; - ngl_per_device_test[id_dense_start_test].n_layer--; - ngl_per_device_test[id_dense_start_test].n_part--; - ngl_per_device_test[id].n_layer++; - ngl_per_device_test[id].n_part++; - if (ngl_per_device_test[id_dense_start_test].n_part == 0) { - id_dense_start_test++; - } - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; - std::vector overflow_bufts_test = overflow_bufts; - if (id < nd - 1) { - overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]); - } - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); - std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - } else { - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - } - } - - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); - } - - // print info for devices that were not changed during the conversion from dense only to full layers: - for (size_t id = id_dense_start + 1; id < nd; id++) { - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); - } - - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); -} - -enum llama_params_fit_status llama_params_fit( - const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, - float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) { - const int64_t t0_us = llama_time_us(); - llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS; - try { - llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level); - LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__); - } catch (const llama_params_fit_exception & e) { - LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what()); - status = LLAMA_PARAMS_FIT_STATUS_FAILURE; - } catch (const std::runtime_error & e) { - LLAMA_LOG_ERROR("%s: encountered an error while trying to fit params to free device memory: %s\n", __func__, e.what()); - status = LLAMA_PARAMS_FIT_STATUS_ERROR; - } - const int64_t t1_us = llama_time_us(); - LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6); - return status; -} - struct llama_sampler_chain_params llama_sampler_chain_default_params() { struct llama_sampler_chain_params result = { /*.no_perf =*/ true, @@ -828,7 +113,7 @@ int64_t llama_time_us(void) { // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback static int llama_model_load(struct gguf_context * metadata, llama_model_set_tensor_data_t set_tensor_data, void * set_tensor_data_ud, - const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) { + const std::string & fname, std::vector & splits, FILE * file, llama_model & model, llama_model_params & params) { // loading time will be recalculated after the first eval, so // we take page faults deferred by mmap() into consideration model.t_load_us = 0; @@ -837,7 +122,7 @@ static int llama_model_load(struct gguf_context * metadata, llama_model_set_tens model.t_start_us = tm.t_start_us; try { - llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, params.use_mmap, params.use_direct_io, + llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, file, params.use_mmap, params.use_direct_io, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); ml.print_info(); @@ -889,8 +174,24 @@ static struct llama_model * llama_model_load_from_file_impl( void * set_tensor_data_ud, const std::string & path_model, std::vector & splits, + FILE * file, struct llama_model_params params) { - GGML_ASSERT((metadata == nullptr) != path_model.empty() && "exactly one out of metadata and path_model needs to be defined"); + { + int n_sources_defined = 0; + if (metadata != nullptr) { + n_sources_defined++; + } + if (!path_model.empty()) { + n_sources_defined++; + } + if (file != nullptr) { + n_sources_defined++; + } + if (n_sources_defined != 1) { + LLAMA_LOG_ERROR("%s: exactly one out metadata, path_model, and file must be defined\n", __func__); + return nullptr; + } + } ggml_time_init(); if (!params.vocab_only && ggml_backend_reg_count() == 0) { @@ -919,58 +220,111 @@ static struct llama_model * llama_model_load_from_file_impl( // create list of devices to use with this model if (params.devices) { - for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) { - model->devices.push_back(*dev); + if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) { + size_t n_devs = 0; + while (params.devices[n_devs]) { + n_devs++; + } + if (n_devs == 0) { + LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__); + return nullptr; + } + LLAMA_LOG_INFO("%s: creating a Meta device with %zu devices\n", __func__, n_devs); + for (size_t i = 0; i < n_devs; ++i) { + LLAMA_LOG_INFO("%s: - device %zu: %s\n", __func__, i, ggml_backend_dev_name(params.devices[i])); + } + model->get_split_state_ud.n_devices = n_devs; + model->get_split_state_ud.model = model; + model->devices.push_back({ + true, ggml_backend_meta_device( + params.devices, n_devs, llama_meta_device_get_split_state, &model->get_split_state_ud) + }); + } else { + for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) { + model->devices.push_back({false, *dev}); + } } } else { // default device selection // build list of available devices - std::vector gpus; - std::vector igpus; - std::vector rpc_servers; - - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { - ggml_backend_dev_t dev = ggml_backend_dev_get(i); - switch (ggml_backend_dev_type(dev)) { - case GGML_BACKEND_DEVICE_TYPE_CPU: - case GGML_BACKEND_DEVICE_TYPE_ACCEL: - // skip CPU backends since they are handled separately - break; - - case GGML_BACKEND_DEVICE_TYPE_GPU: { - ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); - if (ggml_backend_reg_name(reg) == std::string("RPC")) { - rpc_servers.push_back(dev); - } else { - // check if there is already a GPU with the same device id - ggml_backend_dev_props props; - ggml_backend_dev_get_props(dev, &props); - auto it = std::find_if(gpus.begin(), gpus.end(), [&props](ggml_backend_dev_t d) { - ggml_backend_dev_props d_props; - ggml_backend_dev_get_props(d, &d_props); - if (props.device_id && d_props.device_id) { - return strcmp(props.device_id, d_props.device_id) == 0; - } - return false; - }); - - if (it != gpus.end()) { - LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n", - __func__, - ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), - props.device_id ? props.device_id : "unknown id", - ggml_backend_dev_name(*it), ggml_backend_dev_description(*it)); + std::vector gpus; + std::vector igpus; + std::vector rpc_servers; + + if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) { + std::vector devs; + devs.reserve(ggml_backend_dev_count()); + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + auto * dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_buffer_type(dev) == ggml_backend_cpu_buffer_type()) { + LLAMA_LOG_INFO("%s: skipping %s (%s) for tensor parallelism\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev)); + continue; + } + devs.push_back(dev); + } + if (devs.empty()) { + LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__); + return nullptr; + } + + LLAMA_LOG_INFO("%s: creating a Meta device for tensor parallelism from %zu devices:\n", __func__, devs.size()); + for (size_t i = 0; i < devs.size(); ++i) { + LLAMA_LOG_INFO("%s: - device %zu: %s (%s)\n", __func__, i, ggml_backend_dev_name(devs[i]), ggml_backend_dev_description(devs[i])); + } + + GGML_ASSERT(!devs.empty()); + model->get_split_state_ud.n_devices = devs.size(); + model->get_split_state_ud.model = model; + gpus.push_back({ + true, ggml_backend_meta_device( + devs.data(), devs.size(), llama_meta_device_get_split_state, &model->get_split_state_ud) + }); + } else { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + switch (ggml_backend_dev_type(dev)) { + case GGML_BACKEND_DEVICE_TYPE_CPU: + case GGML_BACKEND_DEVICE_TYPE_ACCEL: + // skip CPU backends since they are handled separately + break; + + case GGML_BACKEND_DEVICE_TYPE_GPU: { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + if (ggml_backend_reg_name(reg) == std::string("RPC")) { + rpc_servers.push_back({false, dev}); } else { - gpus.push_back(dev); + // check if there is already a GPU with the same device id + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + auto it = std::find_if(gpus.begin(), gpus.end(), [&props](const llama_device & d) { + ggml_backend_dev_props d_props; + ggml_backend_dev_get_props(d.dev, &d_props); + if (props.device_id && d_props.device_id) { + return strcmp(props.device_id, d_props.device_id) == 0; + } + return false; + }); + + if (it != gpus.end()) { + LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n", + __func__, + ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + props.device_id ? props.device_id : "unknown id", + ggml_backend_dev_name(it->dev), ggml_backend_dev_description(it->dev)); + } else { + gpus.push_back({false, dev}); + } } + break; } - break; - } - case GGML_BACKEND_DEVICE_TYPE_IGPU: - igpus.push_back(dev); - break; + case GGML_BACKEND_DEVICE_TYPE_IGPU: + igpus.push_back({false, dev}); + break; + case GGML_BACKEND_DEVICE_TYPE_META: + GGML_ABORT("fatal error"); + } } } @@ -996,22 +350,22 @@ static struct llama_model * llama_model_load_from_file_impl( llama_model_free(model); return nullptr; } - ggml_backend_dev_t main_gpu = model->devices[params.main_gpu]; + llama_device main_gpu = model->devices[params.main_gpu]; model->devices.clear(); model->devices.push_back(main_gpu); } } - for (auto * dev : model->devices) { + for (const auto & dev : model->devices) { ggml_backend_dev_props props; - ggml_backend_dev_get_props(dev, &props); + ggml_backend_dev_get_props(dev.dev, &props); LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__, - ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + ggml_backend_dev_name(dev.dev), ggml_backend_dev_description(dev.dev), props.device_id ? props.device_id : "unknown id", props.memory_free/1024/1024); } - const int status = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, *model, params); + const int status = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, file, *model, params); GGML_ASSERT(status <= 0); if (status < 0) { if (status == -1) { @@ -1037,7 +391,7 @@ struct llama_model * llama_model_init_from_user( std::vector splits = {}; params.use_mmap = false; params.use_extra_bufts = false; - return llama_model_load_from_file_impl(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, params); + return llama_model_load_from_file_impl(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, /*file*/ nullptr, params); } // deprecated struct llama_model * llama_load_model_from_file( @@ -1050,7 +404,7 @@ struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params) { std::vector splits = {}; - return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, params); + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, /*file*/ nullptr, params); } struct llama_model * llama_model_load_from_splits( @@ -1066,7 +420,17 @@ struct llama_model * llama_model_load_from_splits( for (size_t i = 0; i < n_paths; ++i) { splits.push_back(paths[i]); } - return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, splits.front(), splits, params); + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, splits.front(), splits, /*file*/ nullptr, params); +} + +struct llama_model * llama_model_load_from_file_ptr(FILE * file, struct llama_model_params params) { + if (!file) { + LLAMA_LOG_ERROR("%s: file is NULL\n", __func__); + return nullptr; + } + std::string path_model; + std::vector splits = {}; + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, file, params); } void llama_model_save_to_file(const struct llama_model * model, const char * path_model) { diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index c6e102abe51..eb869814097 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -154,6 +154,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors LLAMA_FTYPE_MOSTLY_NVFP4 = 39, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q1_0 = 40, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; @@ -191,9 +192,10 @@ extern "C" { LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); enum llama_split_mode { - LLAMA_SPLIT_MODE_NONE = 0, // single GPU - LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs - LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported + LLAMA_SPLIT_MODE_TENSOR = 3, }; // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) @@ -380,22 +382,33 @@ extern "C" { size_t n_samplers; }; + struct llama_model_tensor_override { + const char * pattern; + enum ggml_type type; + }; + + struct llama_model_imatrix_data { + const char * name; + const float * data; + size_t size; + }; + // model quantization parameters typedef struct llama_model_quantize_params { - int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() - enum llama_ftype ftype; // quantize to this llama_ftype - enum ggml_type output_tensor_type; // output tensor type - enum ggml_type token_embedding_type; // token embeddings tensor type - bool allow_requantize; // allow quantizing non-f32/f16 tensors - bool quantize_output_tensor; // quantize output.weight - bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored - bool pure; // quantize all tensors to the default type - bool keep_split; // quantize to the same number of shards - bool dry_run; // calculate and show the final quantization size without performing quantization - void * imatrix; // pointer to importance matrix data - void * kv_overrides; // pointer to vector containing overrides - void * tensor_types; // pointer to vector containing tensor types - void * prune_layers; // pointer to vector containing layer indices to prune + int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + enum llama_ftype ftype; // quantize to this llama_ftype + enum ggml_type output_tensor_type; // output tensor type + enum ggml_type token_embedding_type; // token embeddings tensor type + bool allow_requantize; // allow quantizing non-f32/f16 tensors + bool quantize_output_tensor; // quantize output.weight + bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored + bool pure; // quantize all tensors to the default type + bool keep_split; // quantize to the same number of shards + bool dry_run; // calculate and show the final quantization size without performing quantization + const struct llama_model_imatrix_data * imatrix; // pointer to importance matrix data + const struct llama_model_kv_override * kv_overrides; // pointer to kv overrides + const struct llama_model_tensor_override * tt_overrides; // pointer to tensor overrides + const int32_t * prune_layers; // pointer to layer indices to prune } llama_model_quantize_params; typedef struct llama_logit_bias { @@ -465,6 +478,11 @@ extern "C" { const char * path_model, struct llama_model_params params); + // Load a model from an open FILE pointer + LLAMA_API struct llama_model * llama_model_load_from_file_ptr( + FILE * file, + struct llama_model_params params); + // Load a model from multiple splits (support custom naming scheme) // The paths must be in the correct order LLAMA_API struct llama_model * llama_model_load_from_splits( @@ -493,27 +511,6 @@ extern "C" { // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); - enum llama_params_fit_status { - LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. because no model could be found at the specified path - }; - - // fits mparams and cparams to free device memory (assumes system memory is unlimited) - // - returns true if the parameters could be successfully modified to fit device memory - // - this function is NOT thread safe because it modifies the global llama logger state - // - only parameters that have the same value as in llama_default_model_params are modified - // with the exception of the context size which is modified if and only if equal to 0 - LLAMA_API enum llama_params_fit_status llama_params_fit( - const char * path_model, - struct llama_model_params * mparams, - struct llama_context_params * cparams, - float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements - struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements - size_t * margins, // margins of memory to leave per device in bytes - uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use - enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log - LLAMA_API int64_t llama_time_us(void); LLAMA_API size_t llama_max_devices(void); @@ -636,7 +633,6 @@ extern "C" { // Load a LoRA adapter from file // The adapter is valid as long as the associated model is not freed - // All adapters must be loaded before context creation LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( struct llama_model * model, const char * path_lora); @@ -660,9 +656,8 @@ extern "C" { LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); // Manually free a LoRA adapter - // NOTE: loaded adapters will be free when the associated model is deleted - LLAMA_API DEPRECATED(void llama_adapter_lora_free(struct llama_adapter_lora * adapter), - "adapters are now freed together with the associated model"); + // NOTE: loaded adapters that are not manually freed will be freed when the associated model is deleted + LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); // Get the invocation tokens if the current lora is an alora LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter); @@ -1530,9 +1525,6 @@ extern "C" { LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); - // print a breakdown of per-device memory use via LLAMA_LOG: - LLAMA_API void llama_memory_breakdown_print(const struct llama_context * ctx); - // // training // diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index 9aabe25c965..2790b12111d 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -41,22 +41,13 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para { ggml_tensor * attn_inp = cur; // save input for gate computation - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // compute gate from input ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp); cb(gate, "attn_gate_proj", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - // Q/K normalization Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -77,10 +68,8 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para cb(Kcur, "Kcur_rope", il); } - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, - NULL, NULL, // wo will be applied after gating + NULL, NULL, NULL, // wo will be applied after gating Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -91,7 +80,7 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para cb(cur, "attn_gated", il); // now apply output projection - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_o_proj", il); } diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index 4d65614e466..af44cea6054 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -32,25 +30,15 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); cb(Kcur, "Kcur_normed", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -62,7 +50,7 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur_pos", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index 20b9ffd49eb..2e71f5d9e2a 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -36,30 +35,8 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_para ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -78,7 +55,7 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index b712e08cbd3..f8ca6aff6ab 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -30,18 +30,8 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -60,7 +50,7 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index abd03cd0b97..2d0d05df485 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -29,18 +28,8 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); switch (model.type) { case LLM_TYPE_7B: @@ -67,7 +56,7 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index 25e3369c313..67a7120d622 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -28,30 +28,8 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head_k, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -70,7 +48,7 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index 42098624663..497b4babd0c 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -3,7 +3,6 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -29,15 +28,8 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll // self_attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 0 * sizeof(float) * (n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -56,7 +48,7 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/bert.cpp b/examples/talk-llama/models/bert.cpp index 87331791418..7e046cfd2a4 100644 --- a/examples/talk-llama/models/bert.cpp +++ b/examples/talk-llama/models/bert.cpp @@ -2,7 +2,6 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -28,8 +27,8 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params cb(inpL, "inp_embd", -1); // embed layer norm - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); + cb(inpL, "inp_norm", 0); auto * inp_attn = build_attn_inp_no_cache(); @@ -39,35 +38,8 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params ggml_tensor * cur = inpL; { - ggml_tensor * Qcur; - ggml_tensor * Kcur; - ggml_tensor * Vcur; - - // self-attention - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], - 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - } else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head * n_head, n_tokens); @@ -100,7 +72,7 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } diff --git a/examples/talk-llama/models/bitnet.cpp b/examples/talk-llama/models/bitnet.cpp index ccf5bc8e82b..71526354ca6 100644 --- a/examples/talk-llama/models/bitnet.cpp +++ b/examples/talk-llama/models/bitnet.cpp @@ -28,33 +28,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa // self-attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - // B1.K - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - // B1.V - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -73,7 +48,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - NULL, NULL, + NULL, NULL, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cur = build_norm(cur, @@ -82,8 +57,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cb(cur, "attn_sub_norm", il); cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); - if (model.layers[il].bo) { - cur = ggml_add(ctx0, cur, model.layers[il].bo); + if (model.layers[il].wo_b) { + cur = ggml_add(ctx0, cur, model.layers[il].wo_b); } cb(cur, "attn_out", il); } @@ -121,6 +96,9 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "l_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index b1c19bb58a2..f3b0999bf54 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -2,7 +2,6 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -16,8 +15,8 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, - LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + LLM_NORM, 0); + cb(inpL, "inp_norm", 0); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -30,22 +29,11 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 2f24105fa14..21deaba1a6d 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -36,22 +36,10 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { - Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, - ggml_element_size(Qcur) * n_embd_head, - ggml_element_size(Qcur) * n_embd_head * n_head, - 0); - cb(Qcur, "Qcur", il); - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, @@ -60,12 +48,6 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr } if (model.layers[il].attn_k_norm) { - Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, - ggml_element_size(Kcur) * n_embd_head, - ggml_element_size(Kcur) * n_embd_head * n_head_kv, - 0); - cb(Kcur, "Kcur", il); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, @@ -73,10 +55,6 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr cb(Kcur, "Kcur", il); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -94,7 +72,7 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index 5887ed22e7e..7d4a43fdca5 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -3,7 +3,6 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -30,37 +29,8 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv == nullptr) { - Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } else { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); Qcur = ggml_rope_ext( @@ -80,7 +50,7 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -111,8 +81,13 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = build_norm(inpL, diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index e8e13e143f2..3ceb5835b85 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -2,7 +2,6 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); @@ -28,15 +27,8 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_gr // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -55,7 +47,7 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 2ef2b6e389b..be3eeeddac7 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -28,18 +28,20 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa for (int il = 0; il < n_layer; ++il) { // get either the text or image weight tensors - ggml_tensor *wqkv, *wo; + ggml_tensor *wqkv, *wo, *wo_s; ggml_tensor *ffn_gate, *ffn_down, *ffn_up; if (is_text) { wqkv = model.layers[il].wqkv; wo = model.layers[il].wo; + wo_s = model.layers[il].wo_s; ffn_gate = model.layers[il].ffn_gate; ffn_down = model.layers[il].ffn_down; ffn_up = model.layers[il].ffn_up; } else { wqkv = model.layers[il].visexp_attn_wqkv; wo = model.layers[il].visexp_attn_wo; + wo_s = nullptr; ffn_gate = model.layers[il].visexp_ffn_gate; ffn_down = model.layers[il].visexp_ffn_down; ffn_up = model.layers[il].visexp_ffn_up; @@ -64,7 +66,7 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa Kcur = ggml_rope(ctx0, Kcur, inp_pos, n_embd_head, rope_type); cur = build_attn(inp_attn, - wo, nullptr, + wo, nullptr, wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); @@ -86,6 +88,10 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/cohere2-iswa.cpp b/examples/talk-llama/models/cohere2-iswa.cpp index 7c71a59ae7f..670b08e7d97 100644 --- a/examples/talk-llama/models/cohere2-iswa.cpp +++ b/examples/talk-llama/models/cohere2-iswa.cpp @@ -36,30 +36,8 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (is_swa) { Qcur = ggml_rope_ext( @@ -80,7 +58,7 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index ba1230f0419..067961caa08 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -32,27 +32,8 @@ llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM, il); @@ -73,7 +54,7 @@ llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index 73eb5cd24e7..0e882721807 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -2,7 +2,6 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); @@ -30,19 +29,8 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(cur, "wqkv_clamped", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -61,7 +49,7 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index ac448bfcaa8..30272eabd69 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -47,27 +45,8 @@ llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -80,7 +59,7 @@ llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp index 3432359e03a..671b72dfead 100644 --- a/examples/talk-llama/models/deepseek.cpp +++ b/examples/talk-llama/models/deepseek.cpp @@ -35,27 +35,8 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -68,7 +49,7 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index d437fe29e71..303fc72c610 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -2,6 +2,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B + bool is_ocr = model.arch == LLM_ARCH_DEEPSEEK2OCR; + const bool is_mla = hparams.is_mla(); // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA @@ -54,7 +57,38 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self_attention - { + if (is_ocr) { + const int n_embed_head = hparams.n_embd / hparams.n_head(); + const int ocr_rope_type = GGML_ROPE_TYPE_NEOX; + GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v); + + ggml_tensor * Qcur = NULL; + ggml_tensor * Kcur = NULL; + ggml_tensor * Vcur = NULL; + + Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Qcur, "q", il); + cb(Kcur, "k", il); + cb(Vcur, "v", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embed_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embed_head, n_head, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens); + + GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); + cb(Qcur, "q_pe", il); + cb(Kcur, "k_pe", il); + + cur = build_attn(inp_attn_kv, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + else { ggml_tensor * q = NULL; const bool is_lite = model.layers[il].wq; @@ -148,7 +182,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group) cur = build_attn(inp_attn_k, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il); } else { ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); @@ -185,7 +219,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) cur = build_attn(inp_attn_kv, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 07236dd27c9..5d1750fedda 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -29,18 +29,8 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -59,7 +49,7 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 4edc8530cb3..8e7d9ae64c7 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { //copied from qwen2 @@ -31,22 +29,8 @@ llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_para // self-attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -59,7 +43,7 @@ llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp index 63baf152c40..fc6a3e17a09 100644 --- a/examples/talk-llama/models/ernie4-5-moe.cpp +++ b/examples/talk-llama/models/ernie4-5-moe.cpp @@ -30,27 +30,8 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -63,7 +44,7 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index d548de0547b..033ba409eab 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -29,27 +29,8 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap } // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -62,7 +43,7 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { diff --git a/examples/talk-llama/models/eurobert.cpp b/examples/talk-llama/models/eurobert.cpp index e8628d165d0..43fff4daf3a 100644 --- a/examples/talk-llama/models/eurobert.cpp +++ b/examples/talk-llama/models/eurobert.cpp @@ -24,17 +24,8 @@ llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_grap LLM_NORM_RMS, il); { - ggml_tensor * Qcur; - ggml_tensor * Kcur; - ggml_tensor * Vcur; - - Qcur = build_lora_mm(model.layers[il].wq, cur); - Kcur = build_lora_mm(model.layers[il].wk, cur); - Vcur = build_lora_mm(model.layers[il].wv, cur); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -53,7 +44,7 @@ llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } @@ -82,6 +73,7 @@ llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_grap cur = ggml_add(ctx0, cur, ffn_inp); + // input for next layer inpL = cur; } cur = inpL; diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index ea75701c528..7b88a31d39d 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -35,18 +35,8 @@ llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -65,7 +55,7 @@ llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn_iswa, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index d4eea58e2f1..4f845bf4106 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -34,27 +32,8 @@ llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_pa ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -67,7 +46,7 @@ llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 755af3b747b..34bee3b8fe9 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -1,6 +1,5 @@ #include "models.h" - template llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { @@ -39,18 +38,8 @@ llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_ { ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -69,7 +58,7 @@ llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index ff842d93a41..05accf90fad 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -27,19 +27,8 @@ llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self-attention - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, hparams.rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -52,7 +41,7 @@ llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_gr cb(Vcur, "Vcur-post-rope", il); ggml_tensor * attn_out = build_attn(inp->get_attn(), - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index 9fcba508878..2f65fa56e1f 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -1,9 +1,7 @@ #include "models.h" - llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); @@ -42,12 +40,8 @@ llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_pa cur = attn_norm; } - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -67,7 +61,7 @@ llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/gemma-embedding.cpp b/examples/talk-llama/models/gemma-embedding.cpp index 98110d45e3b..b6de9551c52 100644 --- a/examples/talk-llama/models/gemma-embedding.cpp +++ b/examples/talk-llama/models/gemma-embedding.cpp @@ -9,7 +9,7 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, inpL = build_inp_embd(model.tok_embd); - // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); cb(inpL, "inp_scaled", -1); @@ -31,18 +31,8 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -65,7 +55,7 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 1869efd389a..09d2ff8bae7 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -29,18 +28,8 @@ llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -60,7 +49,7 @@ llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_para cb(Qcur, "Qcur_scaled", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/gemma2-iswa.cpp b/examples/talk-llama/models/gemma2-iswa.cpp index 3927ddd297b..0ef07df8d01 100644 --- a/examples/talk-llama/models/gemma2-iswa.cpp +++ b/examples/talk-llama/models/gemma2-iswa.cpp @@ -31,18 +31,8 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -61,7 +51,7 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index bbb4d9a81e8..0da4af21c17 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -9,7 +9,7 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr inpL = build_inp_embd(model.tok_embd); - // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); cb(inpL, "inp_scaled", -1); @@ -47,18 +47,8 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -84,7 +74,7 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/gemma3n-iswa.cpp b/examples/talk-llama/models/gemma3n-iswa.cpp index 8ce2ae39c2f..f8095417e06 100644 --- a/examples/talk-llama/models/gemma3n-iswa.cpp +++ b/examples/talk-llama/models/gemma3n-iswa.cpp @@ -1,5 +1,12 @@ #include "models.h" +// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim +static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { + GGML_ASSERT(idx < (int) x->ne[2]); + return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), + idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); +} + llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), @@ -12,7 +19,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const inpL = build_inp_embd(model.tok_embd); - // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); cb(inpL, "inp_scaled", -1); @@ -22,8 +29,11 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const // TODO: is causal == true correct? might need some changes auto * inp_attn = build_attn_inp_kv_iswa(); - // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer] - ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs()); + ggml_tensor * inp_per_layer = build_inp_per_layer(); + ggml_build_forward_expand(gf, inp_per_layer); + + // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer] + inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer); // inpL now has only 1 altup, project it to the rest of the altups // these "added" altups will be concat to the last dim of inpL @@ -37,8 +47,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup] cb(inpL, "inp_stacked", -1); } - // inpL now has shape: [n_embd, n_tokens, n_altup] - // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer] + // inpL now has shape: [n_embd, n_tokens, n_altup] for (int il = 0; il < n_layer; ++il) { // this block is made to be closely resemble Gemma3p5DecoderLayer on python code @@ -49,8 +58,8 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup] // predicted value will go through self-attention and laurel - ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens] - cur = active_prediction; + ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); // [n_embd, n_tokens] + cur = active_prediction; cb(cur, "active_prediction", il); // norm @@ -62,19 +71,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const // self-attention if (hparams.has_kv(il)) { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -94,7 +91,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(Kcur, "Kcur_pos", il); cur = build_attn(inp_attn, model.layers[il].wo, - NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, + NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } else { // reuse KV cache of earlier layers @@ -110,7 +107,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(Qcur, "Qcur_pos", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); @@ -151,12 +148,13 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const ggml_tensor * first_prediction; // [n_embd, n_tokens] { - first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens] + first_prediction = ggml_view_2d_slice(ctx0, corrected, i_altup_act); // [n_embd, n_tokens] first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale); first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction); first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens] cb(first_prediction, "first_prediction_gated", il); - ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens] + + ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_altup, n_tokens] first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens] cb(first_prediction, "first_prediction_scaled", il); @@ -167,7 +165,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const } // equivalent to python code: corrected_predictions[1:] += first_prediction { - ggml_tensor * slice_first = view_2d_slice(corrected, 0); + ggml_tensor * slice_first = ggml_view_2d_slice(ctx0, corrected, 0); ggml_tensor * slice_rest = ggml_view_3d( ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd), ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected)); @@ -185,7 +183,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const // cur now has multiple altup(s), we want to merge them back to 1 altup { - ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens] + ggml_tensor * target_magnitude = calc_magnitude(ggml_view_2d_slice(ctx0, cur, i_altup_act)); // [n_embd, n_tokens] // do a view to skip the first slice (active altup) ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd), @@ -197,9 +195,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(altup_unembd, "altup_unembd", -1); // equivalent to torch.mean(hidden_states, dim=0) - cur = view_2d_slice(cur, 0); // [n_embd, n_tokens] + cur = ggml_view_2d_slice(ctx0, cur, 0); // [n_embd, n_tokens] for (int i = 0; i < n_altup - 1; ++i) { - cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i)); + cur = ggml_add(ctx0, cur, ggml_view_2d_slice(ctx0, altup_unembd, i)); } cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens] cb(cur, "unembd_merged", -1); @@ -235,39 +233,34 @@ ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) { return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); } -// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim -ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) { - GGML_ASSERT(idx < (int) x->ne[2]); - return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), - idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); -} - // equivalent to get_per_layer_inputs() in python code // output shape: [n_embd_altup, n_layer, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { +ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() { auto inp = std::make_unique(n_embd); ggml_tensor * inp_per_layer; + float tok_embd_scale = sqrtf((float) n_embd_altup); if (ubatch.token) { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); ggml_set_input(inp->tokens); res->t_inp_tokens = inp->tokens; - inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens); + inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens); inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); - inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup)); + inp_per_layer = ggml_scale (ctx0, inp_per_layer, tok_embd_scale); cb(inp_per_layer, "inp_per_layer_selected", -1); res->add_input(std::move(inp)); } else { - // Vision embedding path: use padding token (ID=0) embedding + // Multimodal embedding path: use padding token (ID=0) embedding // TODO: verify if this is the correct behavior in transformers implementation - const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer + const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_altup * n_layer // Extract and dequantize padding token embedding (row 0) - ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0); - inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32); + ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0); + inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale); // Reshape to [n_embd_altup, n_layer, 1] inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1); - cb(inp_per_layer, "inp_per_layer_vision", -1); + cb(inp_per_layer, "inp_per_layer_multimodal", -1); } return inp_per_layer; } @@ -275,18 +268,19 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { // equivalent to project_per_layer_inputs() in python code // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim // output shape: [n_embd_altup, n_tokens, n_layer] -ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) { +ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); const float per_layer_input_scale = 1.0f / sqrtf(2.0f); - ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds); - per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale); - per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); - per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, - -1); // [n_embd_altup, n_layer, n_tokens] + ggml_tensor * per_layer_proj; + per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch); + per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale); + per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); + + per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1); cb(per_layer_proj, "per_layer_proj", -1); - inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer); + inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer); inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); cb(inp_per_layer, "inp_per_layer", -1); @@ -337,7 +331,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tenso // input cur shape: [n_embd, n_tokens, n_altup] // output shape: [n_embd, n_tokens, n_altup] ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) { - ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens] + ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); // [n_embd, n_tokens] ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] cb(modalities, "modalities", il); @@ -365,7 +359,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, g ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] cb(modalities, "modalities", il); - ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); + ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens] cb(innovation, "innovation", il); diff --git a/examples/talk-llama/models/gemma4-iswa.cpp b/examples/talk-llama/models/gemma4-iswa.cpp new file mode 100644 index 00000000000..c7fb7747414 --- /dev/null +++ b/examples/talk-llama/models/gemma4-iswa.cpp @@ -0,0 +1,322 @@ +#include "models.h" + +// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim +static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { + GGML_ASSERT(idx < (int) x->ne[2]); + return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), + idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); +} + +llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params), + model(model), + n_embd_per_layer(model.hparams.n_embd_per_layer) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // TODO: is causal == true correct? might need some changes + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * inp_per_layer = nullptr; + if (model.per_layer_tok_embd) { + inp_per_layer = build_inp_per_layer(); + ggml_build_forward_expand(gf, inp_per_layer); + + // inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer] + inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer); + } + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_embd_head = hparams.n_embd_head_k(il); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v(il)); + + const int64_t n_head = hparams.n_head(il); + const int64_t n_head_kv = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + const int n_rot_l = hparams.n_rot(il); + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_tensor * freq_factors = nullptr; + if (!hparams.is_swa(il)) { + // full_attention layers use rope_freqs for proportional rope + freq_factors = model.layers[il].rope_freqs; + } + + // Q projection (shared for both non-KV and KV layers) + // this is to mirror Gemma4Attention in pytorch code + ggml_tensor * Qcur; + { + Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); + cb(Qcur, "Qcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + } + + // self-attention + if (hparams.has_kv(il)) { + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = model.layers[il].wv + ? build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s) + : Kcur; // if v_proj is not present, use Kcur as Vcur + cb(Vcur, "Vcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps); + + cb(Kcur, "Kcur_normed", il); + cb(Vcur, "Vcur_normed", il); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Kcur, "Kcur_pos", il); + + cur = build_attn(inp_attn, model.layers[il].wo, + nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, + hparams.f_attention_scale, il); + } else { + // reuse KV cache of earlier layers + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, model.layers[il].wo_s, + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + } + + // TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + cur = build_norm(cur, + model.layers[il].attn_post_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); + cb(attn_out, "attn_out", il); + + // feed-forward network + const bool is_moe_layer = model.layers[il].ffn_gate_inp != nullptr; + if (is_moe_layer) { + // MLP (shared exp) + ggml_tensor * cur_mlp = build_norm(attn_out, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur_mlp, "ffn_norm_1", il); + + cur_mlp = build_ffn(cur_mlp, + model.layers[il].ffn_up, nullptr, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, nullptr, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cur_mlp = build_norm(cur_mlp, + model.layers[il].ffn_post_norm_1, nullptr, + LLM_NORM_RMS, il); + cb(cur_mlp, "ffn_mlp", il); + + // Expert FFN + ggml_tensor * cur_moe = build_norm(attn_out, + model.layers[il].ffn_pre_norm_2, nullptr, + LLM_NORM_RMS, il); + cb(cur_moe, "ffn_norm_2", il); + + // custom MoE logits calculation (router operates on attn_out, not cur) + ggml_tensor * tmp = ggml_rms_norm(ctx0, attn_out, hparams.f_norm_rms_eps); + tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) n_embd)); + tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_gate_inp_s); + ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, tmp); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + + cur_moe = build_moe_ffn(cur_moe, + nullptr, // gate_inp + nullptr, // up_exps + nullptr, // gate_exps + model.layers[il].ffn_down_exps, + nullptr, // exp_probs_b (not used for gemma4) + n_expert, n_expert_used, + LLM_FFN_GELU, true, + 1.0f, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, logits, + model.layers[il].ffn_gate_up_exps, + nullptr, // up_exps_s + nullptr, // gate_exps_s + model.layers[il].ffn_down_exps_s); + cur_moe = build_norm(cur_moe, + model.layers[il].ffn_post_norm_2, nullptr, + LLM_NORM_RMS, il); + cb(cur_moe, "ffn_moe", il); + + cur = ggml_add(ctx0, cur_mlp, cur_moe); + cb(cur, "ffn_moe_combined", il); + } else { + cur = build_norm(attn_out, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, nullptr, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + cur = build_norm(cur, + model.layers[il].ffn_post_norm, nullptr, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + // residual connection + cur = ggml_add(ctx0, cur, attn_out); + + // per-layer embedding + if (inp_per_layer) { + ggml_tensor * pe_in = cur; + cb(cur, "pe_in", il); + + cur = build_lora_mm(model.layers[il].per_layer_inp_gate, cur); // [n_embd_per_layer, n_tokens] + cur = ggml_gelu(ctx0, cur); + + ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens] + + // TODO @ngxson : improve this + if (il == n_layer - 1 && inp_out_ids) { + inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids); + } + + cur = ggml_mul(ctx0, cur, inp_this_layer); + cur = build_lora_mm(model.layers[il].per_layer_proj, cur); // [n_embd, n_tokens] + cur = build_norm(cur, model.layers[il].per_layer_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "per_layer_embd_out", il); + + // residual connection + cur = ggml_add(ctx0, pe_in, cur); + } + + // layer_scalar + if (model.layers[il].out_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + } + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + if (hparams.f_final_logit_softcapping) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +// equivalent to get_per_layer_inputs() in python code +// output shape: [n_embd_per_layer, n_layer, n_tokens] +ggml_tensor * llm_build_gemma4_iswa::build_inp_per_layer() { + auto inp = std::make_unique(n_embd); + + ggml_tensor * inp_per_layer; + float tok_embd_scale = sqrtf((float) n_embd_per_layer); + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + ggml_set_input(inp->tokens); + res->t_inp_tokens = inp->tokens; + + inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens); + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, n_tokens); + inp_per_layer = ggml_scale (ctx0, inp_per_layer, tok_embd_scale); + cb(inp_per_layer, "inp_per_layer_selected", -1); + + res->add_input(std::move(inp)); + } else { + // Multimodal embedding path: use padding token (ID=0) embedding + // TODO: verify if this is the correct behavior in transformers implementation + const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_per_layer * n_layer + + // Extract and dequantize padding token embedding (row 0) + ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0); + inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale); + + // Reshape to [n_embd_per_layer, n_layer, 1] + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, 1); + cb(inp_per_layer, "inp_per_layer_multimodal", -1); + } + return inp_per_layer; +} + +// equivalent to project_per_layer_inputs() in python code +// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim +// inp_batch shape: [n_embd, n_tokens] +// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from build_inp_per_layer) +// output shape: [n_embd_per_layer, n_tokens, n_layer] +ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { + const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); + const float per_layer_input_scale = 1.0f / sqrtf(2.0f); + + // note: this matrix multiplication will be performed in the input layer (i.e. on the CPU) + ggml_tensor * per_layer_proj; + per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch); + per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale); + per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens); + + per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS, -1); + cb(per_layer_proj, "per_layer_proj", -1); + + inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); + cb(inp_per_layer, "inp_per_layer", -1); + + // permute to shape: [n_embd_per_layer, n_tokens, n_layer] + inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); + return inp_per_layer; +} diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 7938545ed8a..8d4f4a01553 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -38,27 +38,8 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // Apply Q/K norm if available (GLM-4.5 355B variant) if (model.layers[il].attn_q_norm) { @@ -94,7 +75,7 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_transformer_layers - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index b6ad8febed3..f0bfda393fa 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -1,10 +1,7 @@ #include "models.h" - - llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -41,40 +38,8 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv == nullptr) { - Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } else { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], - 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_mrope) { Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, @@ -100,7 +65,7 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_transformer_layers - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index cb1238f2d34..f8dc53eb723 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -2,7 +2,6 @@ llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -34,22 +33,11 @@ llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 1c8fe6c836d..0016ddede43 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -1,9 +1,7 @@ #include "models.h" - llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -28,15 +26,8 @@ llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_ // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -55,7 +46,7 @@ llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index 9b54a38c386..e983742bef5 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -73,31 +73,7 @@ ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * const llama_model & model, const int64_t n_embd_head, const int il) { - // compute Q and K and (optionally) RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, hparams.n_head(il), hparams.n_head_kv(il), il); const bool use_rope = hparams.rope_finetuned; if (use_rope) { @@ -116,7 +92,7 @@ ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index 7a7e1664c29..6ea90285225 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -76,31 +76,8 @@ ggml_tensor * llm_build_granite::build_attention_layer( const int64_t n_embd_head, const int il) { - // compute Q and K and (optionally) RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, hparams.n_head(il), hparams.n_head_kv(il), il); const bool use_rope = hparams.rope_finetuned; if (use_rope) { @@ -124,7 +101,7 @@ ggml_tensor * llm_build_granite::build_attention_layer( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 580d63e36ae..b8f35afdc03 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -30,27 +30,8 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index aa60d3e9388..151108a2a71 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -30,18 +30,8 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -60,7 +50,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/hunyuan-dense.cpp b/examples/talk-llama/models/hunyuan-dense.cpp index 6a51707c85b..1cd85d6d9d4 100644 --- a/examples/talk-llama/models/hunyuan-dense.cpp +++ b/examples/talk-llama/models/hunyuan-dense.cpp @@ -6,6 +6,11 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); + const bool use_mrope = hparams.use_mrope(); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + ggml_tensor * cur; ggml_tensor * inpL; @@ -34,44 +39,39 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + if (use_mrope) { + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); @@ -83,7 +83,7 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons cb(Qcur, "Qcur_norm", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index 806c30b3667..ffe1664b0e1 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -35,27 +35,8 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -84,7 +65,7 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll cb(Qcur, "Qcur_norm", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index 441d250268e..83be2ca0aee 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -30,27 +30,8 @@ llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index 135bf288ba1..31101f3c14b 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -2,7 +2,6 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -24,22 +23,11 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*cur->nb[0]*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -66,8 +54,14 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = build_norm(inpL, model.output_norm, diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index 2cfe484eb52..507e04fa4aa 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -31,25 +31,8 @@ llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_para // Self-attention with separate Q, K, V projections { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur_bias", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur_bias", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur_bias", il); - - // Reshape for attention - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // Apply RoPE Qcur = ggml_rope_ext( @@ -68,7 +51,7 @@ llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_para cb(Kcur, "Kcur_rope", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index c0c89de187a..f82b7795c87 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -24,25 +24,12 @@ llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_para } else { // Attention - struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // No RoPE :) cur = build_attn(inp_hybrid->get_attn(), - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp index 4d62f4e7159..58c89c417fc 100644 --- a/examples/talk-llama/models/kimi-linear.cpp +++ b/examples/talk-llama/models/kimi-linear.cpp @@ -268,7 +268,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_tensor * Vcur = kv_cmpr; cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il); + cur = build_attn(inp_attn_k, layer.wo, NULL, layer.wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il); cb(cur, "mla_out", il); } else { // MLA KV cache disabled. Fall back to MHA KV cache. Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens); @@ -299,7 +299,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll // Direct softmax attention (with MHA KV cache) // Use build_attn with inp_attn for proper mask handling - cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il); + cur = build_attn(inp_attn_kv, layer.wo, NULL, layer.wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il); cb(cur, "mla_out", il); } } @@ -362,6 +362,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll cur = build_cvec(cur, il); cb(cur, "l_out", il); + // input for next layer inpL = cur; } cur = inpL; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index dfa322166b1..eb8ec3c803a 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -42,16 +42,8 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_ const auto n_embd_head = hparams.n_embd_head_v(); const auto n_head_kv = hparams.n_head_kv(il); - auto * q = build_lora_mm(model.layers[il].wq, cur); - cb(q, "model.layers.{}.self_attn.q_proj", il); - auto * k = build_lora_mm(model.layers[il].wk, cur); - cb(k, "model.layers.{}.self_attn.k_proj", il); - auto * v = build_lora_mm(model.layers[il].wv, cur); - cb(v, "model.layers.{}.self_attn.v_proj", il); - - q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens); - k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens); - v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens); + auto [q, k, v] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // qk norm q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); @@ -66,7 +58,7 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_ attn_factor, beta_fast, beta_slow); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "model.layers.{}.self_attn.out_proj", il); @@ -177,6 +169,9 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_ cb(ffn_norm_out, "model.layers.{}.ffn_out", il); cur = ggml_add(ctx0, cur, ffn_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); } cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index 18de88fde1f..c756d6fde5f 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -30,18 +30,8 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -66,7 +56,7 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index 0dac9d616ae..501df3c7eaf 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -30,17 +30,8 @@ llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_para // self-attention { // compute separate Q, K, V projections without bias, matching LLaDALlamaBlock - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -53,7 +44,7 @@ llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index e08ae0c0b0e..8d478dc6747 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -43,27 +43,8 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -89,11 +70,8 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra cb(Kcur, "Kcur_normed", il); } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); - if (model.layers[il].wo_s) { - cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); - } cb(cur, "attn_out", il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/llama-iswa.cpp b/examples/talk-llama/models/llama4.cpp similarity index 81% rename from examples/talk-llama/models/llama-iswa.cpp rename to examples/talk-llama/models/llama4.cpp index 67cb9a10ec5..4e4bfb43f33 100644 --- a/examples/talk-llama/models/llama-iswa.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -1,6 +1,7 @@ #include "models.h" -llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +template +llm_build_llama4::llm_build_llama4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -18,7 +19,14 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ ggml_tensor * inp_attn_scale = nullptr; inp_attn_scale = build_inp_attn_scale(); - auto * inp_attn = build_attn_inp_kv_iswa(); + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -46,27 +54,8 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_rope) { Qcur = ggml_rope_ext( @@ -95,7 +84,7 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ cb(Kcur, "Kcur_normed", il); } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -176,3 +165,7 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ ggml_build_forward_expand(gf, cur); } + +// Explicit template instantiations +template struct llm_build_llama4; +template struct llm_build_llama4; diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index a72b7790a1f..8a76931c007 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -30,18 +30,8 @@ llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -66,7 +56,7 @@ llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/mamba-base.cpp b/examples/talk-llama/models/mamba-base.cpp index 9de587db55f..c37f29c487e 100644 --- a/examples/talk-llama/models/mamba-base.cpp +++ b/examples/talk-llama/models/mamba-base.cpp @@ -42,7 +42,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} - ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur); + ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur, layer.ssm_in_s); // split the above in two // => {d_inner, n_seq_tokens, n_seqs} ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); @@ -137,7 +137,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(layer.ssm_out, y); + cur = build_lora_mm(layer.ssm_out, y, layer.ssm_out_s); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} @@ -184,7 +184,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} - ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur, model.layers[il].ssm_in_s); // split the above in three ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0], @@ -278,7 +278,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(model.layers[il].ssm_out, y); + cur = build_lora_mm(model.layers[il].ssm_out, y, model.layers[il].ssm_out_s); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} diff --git a/examples/talk-llama/models/mimo2-iswa.cpp b/examples/talk-llama/models/mimo2-iswa.cpp index 06956915ea0..52c6acfe214 100644 --- a/examples/talk-llama/models/mimo2-iswa.cpp +++ b/examples/talk-llama/models/mimo2-iswa.cpp @@ -58,7 +58,7 @@ llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_ ggml_tensor * sinks = model.layers[il].attn_sinks; cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il); } diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index 89dd7105157..bf12ab73c74 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -134,7 +134,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap cb(k_states, "k_states", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index 83d0916c08c..b809b79f2b9 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -64,7 +64,7 @@ llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 42a5117ff02..b5ae72a2ee1 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -41,27 +41,8 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -86,7 +67,7 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index a86b2b1ebd7..94991c55fe8 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -256,9 +256,11 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params); ggml_tensor * calc_magnitude(ggml_tensor * x); - ggml_tensor * view_2d_slice(ggml_tensor * x, int idx); - ggml_tensor * get_per_layer_inputs(); - ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer); + + // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] + ggml_tensor * build_inp_per_layer(); + ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); + ggml_tensor * gaussian_topk(ggml_tensor * x); ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il); ggml_tensor * altup_predict(ggml_tensor * cur, int il); @@ -266,6 +268,18 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il); }; +struct llm_build_gemma4_iswa : public llm_graph_context { + const llama_model & model; + + const int64_t n_embd_per_layer; + + llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params); + + // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] + ggml_tensor * build_inp_per_layer(); + ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); +}; + struct llm_build_gemma_embedding : public llm_graph_context { llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params); }; @@ -393,8 +407,9 @@ struct llm_build_llama : public llm_graph_context { llm_build_llama(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_llama_iswa : public llm_graph_context { - llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params); +template +struct llm_build_llama4 : public llm_graph_context { + llm_build_llama4(const llama_model & model, const llm_graph_params & params); }; struct llm_build_maincoder : public llm_graph_context { @@ -481,7 +496,7 @@ struct llm_build_phi2 : public llm_graph_context { llm_build_phi2(const llama_model & model, const llm_graph_params & params); }; -template +template struct llm_build_phi3 : public llm_graph_context { llm_build_phi3(const llama_model & model, const llm_graph_params & params); }; @@ -687,12 +702,13 @@ struct llm_build_step35_iswa : public llm_graph_context { llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_t5_dec : public llm_graph_context { - llm_build_t5_dec(const llama_model & model, const llm_graph_params & params); +template +struct llm_build_t5 : public llm_graph_context { + llm_build_t5(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_t5_enc : public llm_graph_context { - llm_build_t5_enc(const llama_model & model, const llm_graph_params & params); +struct llm_build_t5encoder : public llm_build_t5 { + llm_build_t5encoder(const llama_model & model, const llm_graph_params & params); }; struct llm_build_wavtokenizer_dec : public llm_graph_context { diff --git a/examples/talk-llama/models/modern-bert.cpp b/examples/talk-llama/models/modern-bert.cpp index 26020584c6d..5c6a1b5e1bc 100644 --- a/examples/talk-llama/models/modern-bert.cpp +++ b/examples/talk-llama/models/modern-bert.cpp @@ -2,7 +2,6 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -15,8 +14,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll cb(inpL, "inp_embd", -1); // embed layer norm - inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, 0); + cb(inpL, "inp_norm", 0); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -37,14 +36,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll } // self attention - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - const size_t type_size = ggml_type_size(cur->type); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*type_size, cur->nb[1], 0*type_size*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // RoPE Qcur = ggml_rope_ext( @@ -64,7 +57,7 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index ce44a805f5c..8596bbb2024 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -1,10 +1,7 @@ #include "models.h" - - llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -38,25 +35,8 @@ llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & { cur = attn_norm; - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - - if (hparams.f_clamp_kqv > 0.0f) { - cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(cur, "wqkv_clamped", il); - } - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 0 * sizeof(float) * (n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // Q/K Layernorm if (model.layers[il].attn_q_norm) { @@ -76,7 +56,7 @@ llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index 7af99174d16..dc07d43df58 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -65,40 +65,12 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * const llama_model & model, int64_t n_embd_head, int il) { - // compute Q and K - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, hparams.n_head(il), hparams.n_head_kv(il), il); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; @@ -107,9 +79,9 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { if (model.layers[il].ffn_gate_inp == nullptr) { cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s, NULL, NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s, NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); @@ -136,7 +108,10 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il, - router_logits); + router_logits, nullptr, + model.layers[il].ffn_up_exps_s, + nullptr, // no gate + model.layers[il].ffn_down_exps_s); cb(moe_out, "ffn_moe_out", il); if (model.layers[il].ffn_latent_up) { @@ -144,9 +119,9 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla } ggml_tensor * ffn_shexp = build_ffn(inp_emb, - model.layers[il].ffn_up_shexp, NULL, NULL, - NULL /* no gate */ , NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, + model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s, + NULL /* no gate */ , NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s, NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); cb(ffn_shexp, "ffn_shexp", il); diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 34aa6fa5ec4..054b16fe0ef 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -31,27 +31,8 @@ llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_grap // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -70,7 +51,7 @@ llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/neo-bert.cpp b/examples/talk-llama/models/neo-bert.cpp index 2fdf4a3692f..da68024a34d 100644 --- a/examples/talk-llama/models/neo-bert.cpp +++ b/examples/talk-llama/models/neo-bert.cpp @@ -2,7 +2,6 @@ llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -27,17 +26,8 @@ llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_grap LLM_NORM_RMS, il); { - ggml_tensor * Qcur; - ggml_tensor * Kcur; - ggml_tensor * Vcur; - - // self-attention - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // RoPE Qcur = ggml_rope_ext( @@ -57,7 +47,7 @@ llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index 26f4b6ee628..a9974025f07 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -30,27 +30,8 @@ llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (hparams.f_clamp_kqv > 0.0f) { - Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (hparams.f_clamp_kqv > 0.0f) { - Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (hparams.f_clamp_kqv > 0.0f) { - Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 5076359e3f9..308d2a600c2 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -89,7 +89,7 @@ llm_build_olmo2::llm_build_olmo2(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index 83a56a0b3b6..ed46a00ef90 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -68,7 +68,7 @@ llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/openai-moe-iswa.cpp b/examples/talk-llama/models/openai-moe-iswa.cpp index 403f130bc41..50992b8d506 100644 --- a/examples/talk-llama/models/openai-moe-iswa.cpp +++ b/examples/talk-llama/models/openai-moe-iswa.cpp @@ -28,27 +28,8 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_rot, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -67,7 +48,7 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, model.layers[il].attn_sinks, nullptr, 1.0f/sqrtf(float(n_rot)), il); cb(cur, "attn_out", il); diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index 5df6fe3e3ce..514ac33517f 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -73,7 +73,7 @@ llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_ cb(Qcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index 48c01efe368..a5874b6dee7 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -30,30 +30,8 @@ llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - // if (model.layers[il].bq) { - // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - // cb(Qcur, "Qcur", il); - // } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - // if (model.layers[il].bk) { - // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - // cb(Kcur, "Kcur", il); - // } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - // if (model.layers[il].bv) { - // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - // cb(Vcur, "Vcur", il); - // } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -72,7 +50,7 @@ llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp index 340455c2d5f..56cb1d94c5f 100644 --- a/examples/talk-llama/models/paddleocr.cpp +++ b/examples/talk-llama/models/paddleocr.cpp @@ -35,27 +35,8 @@ llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_gr } // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_multi( ctx0, Qcur, inp_pos, nullptr, @@ -74,7 +55,7 @@ llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { diff --git a/examples/talk-llama/models/pangu-embedded.cpp b/examples/talk-llama/models/pangu-embedded.cpp index 1cf0938e68f..53464f21d22 100644 --- a/examples/talk-llama/models/pangu-embedded.cpp +++ b/examples/talk-llama/models/pangu-embedded.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -31,21 +30,8 @@ llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, co // self attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -63,7 +49,7 @@ llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, co cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index 32d40d71fb7..0fb3ffa2e63 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -1,9 +1,7 @@ #include "models.h" - llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -30,29 +28,8 @@ llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - } else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], attn_norm_output, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -74,7 +51,7 @@ llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index 3d11a9459c4..39af285d3c5 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -3,7 +3,6 @@ template llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -39,27 +38,8 @@ llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_ LLM_NORM_RMS, il); cb(attn_norm_output, "attn_norm", il); - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); - cb(cur, "wqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - } - else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], attn_norm_output, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -80,7 +60,7 @@ llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_ cb(Qcur, "Qcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index b7a71211042..4d5c84506c2 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -30,18 +30,8 @@ llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -60,7 +50,7 @@ llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index f02acbc1869..b6142daebd9 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -71,6 +71,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, residual); cb(cur, "ffn_residual", il); + // input for next layer inpL = cur; } @@ -140,7 +141,7 @@ ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv ext_factor, attn_factor, beta_fast, beta_slow); cur = build_attn(inp, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f / sqrtf(float(n_embd_head_v)), il); } diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 32af6e04663..67844c09f24 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -73,7 +73,7 @@ llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_gr const float attn_scale = 1.0f / sqrtf(float(head_dim_q)); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, attn_scale, il); cb(cur, "attn_out", il); @@ -109,6 +109,8 @@ llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_gr cur = build_cvec(cur, il); cb(cur, "l_out", il); + + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index bcb651ce543..abce6b34d04 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -120,7 +120,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & cb(k_states, "k_states", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index 7390f1320bf..44e75d87437 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -28,15 +27,8 @@ llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 2*sizeof(float)*(n_embd)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -56,7 +48,7 @@ llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index 58c10622508..2892dd75087 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -30,30 +30,8 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -72,7 +50,7 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 60761789dc9..5f0a6861b68 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -30,27 +30,8 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp index 9004bab9db1..da7937c7667 100644 --- a/examples/talk-llama/models/qwen2vl.cpp +++ b/examples/talk-llama/models/qwen2vl.cpp @@ -33,21 +33,8 @@ llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_multi( ctx0, Qcur, inp_pos, nullptr, @@ -66,7 +53,7 @@ llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index 52081668477..883dd5f9a90 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -30,18 +30,8 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -66,11 +56,8 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - if (model.layers[il].wo_s) { - cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); - } } if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index 3108bf331ac..87790f08e4e 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -64,6 +64,9 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_ffn", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // Input for next layer inpL = cur; } @@ -176,7 +179,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); @@ -222,9 +225,10 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(beta, "beta", il); beta = ggml_sigmoid(ctx0, beta); + cb(beta, "beta_sigmoid", il); ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); - alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + alpha = ggml_reshape_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); @@ -266,7 +270,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(last_conv_states, "last_conv_states", il); ggml_tensor * state_update_target = - ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); cb(state_update_target, "state_update_target", il); @@ -342,7 +346,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index 165e2412e56..7dc6a23c751 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -64,6 +64,9 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_moe", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // Input for next layer inpL = cur; } @@ -176,7 +179,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); @@ -222,9 +225,10 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(beta, "beta", il); beta = ggml_sigmoid(ctx0, beta); + cb(beta, "beta_sigmoid", il); ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); - alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + alpha = ggml_reshape_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); @@ -266,7 +270,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(last_conv_states, "last_conv_states", il); ggml_tensor * state_update_target = - ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); cb(state_update_target, "state_update_target", il); @@ -342,7 +346,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index dba46618ff2..16bedba994d 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -30,18 +30,8 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -66,11 +56,8 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - if (model.layers[il].wo_s) { - cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); - } } if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index cc479dd075c..1beda70b7cf 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -56,6 +56,9 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_moe", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // Input for next layer inpL = cur; } @@ -154,7 +157,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); @@ -169,7 +172,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( cur = ggml_mul(ctx0, cur, gate); cb(cur, "attn_gated", il); - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_output", il); return cur; @@ -351,7 +354,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(last_conv_states, "last_conv_states", il); ggml_tensor * state_update_target = - ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); cb(state_update_target, "state_update_target", il); @@ -411,19 +414,19 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( GGML_ASSERT(num_v_heads % num_k_heads == 0); int64_t repeat_factor = num_v_heads / num_k_heads; - // repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back - ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); - ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); + // repeat interleave: reshape to (repeat part, 1, remaining part...), do repeat, then reshape back + ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_conv, head_k_dim, 1, num_k_heads, n_seq_tokens * n_seqs); + ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_conv, head_k_dim, 1, num_k_heads, n_seq_tokens * n_seqs); // Repeat along the third dimension (the new dimension with size 1) ggml_tensor * q_repeated = - ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads, n_seq_tokens * n_seqs); ggml_tensor * k_repeated = - ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads, n_seq_tokens * n_seqs); // Reshape back to merge the head and repeat dimensions - // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs] - // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs] + // From [head_dim, repeat_factor, num_k_heads, n_seq_tokens * n_seqs] + // Back to [head_dim, repeat_factor * num_k_heads, n_seq_tokens, n_seqs] q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); } @@ -442,7 +445,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] diff --git a/examples/talk-llama/models/qwen3vl-moe.cpp b/examples/talk-llama/models/qwen3vl-moe.cpp index 195daea66c9..29ee8278a4d 100644 --- a/examples/talk-llama/models/qwen3vl-moe.cpp +++ b/examples/talk-llama/models/qwen3vl-moe.cpp @@ -36,18 +36,8 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -72,7 +62,7 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index bbd5f42ba5b..faa5f2ef3c8 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -36,18 +36,8 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -72,7 +62,7 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index 140700d9e2d..398eb368db0 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -24,25 +24,15 @@ llm_build_refact::llm_build_refact(const llama_model & model, const llm_graph_pa // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index c8e1f43400f..a917c19f25a 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -32,18 +32,8 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -68,7 +58,7 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index 15453fbf50f..032b219d6cb 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -8,7 +8,7 @@ llm_build_rwkv6::llm_build_rwkv6(const llama_model & model, const llm_graph_para ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); auto * rs_inp = build_rs_inp(); diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index 5caf6553dfe..16ffa6901b9 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -9,7 +9,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para ggml_tensor * v_first = nullptr; inpL = build_inp_embd(model.tok_embd); - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); auto * rs_inp = build_rs_inp(); diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index a4d0b75d846..6db8d9781fe 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -32,27 +32,8 @@ llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_grap // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -71,7 +52,7 @@ llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index e2155aacef4..55d09ec325d 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -45,18 +45,8 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, // self_attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_rope) { Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, @@ -69,7 +59,7 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, cb(Kcur, "Kcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -101,6 +91,7 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, cur = ffn_out; cur = ggml_add(ctx0, cur, ffn_inp); + cur = build_cvec(cur, il); cb(cur, "l_out", il); diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index e267fd8f32f..83636dbf546 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -34,27 +34,8 @@ llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_rope) { Qcur = ggml_rope_ext( @@ -74,7 +55,7 @@ llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index ff5aced93b3..9c19abd8835 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -30,30 +30,8 @@ llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_grap // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { Qcur = build_norm(Qcur, @@ -87,7 +65,7 @@ llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index 941cee98219..cf9fe95c35b 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -2,7 +2,6 @@ llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -33,22 +32,11 @@ llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_gr // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index a5965aceb3b..b6d4d5aac1a 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -30,27 +30,8 @@ llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/step35-iswa.cpp b/examples/talk-llama/models/step35-iswa.cpp index 176209cd93e..86aa98909e7 100644 --- a/examples/talk-llama/models/step35-iswa.cpp +++ b/examples/talk-llama/models/step35-iswa.cpp @@ -68,7 +68,7 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); ggml_tensor * attn_out = build_attn(inp_attn, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); // head-wise attention gate: sigmoid(g_proj(x)) in torch @@ -92,7 +92,7 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll } // output projection - cur = build_lora_mm(model.layers[il].wo, attn_out); + cur = build_lora_mm(model.layers[il].wo, attn_out, model.layers[il].wo_s); cb(cur, "attn_proj", il); } @@ -145,9 +145,11 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); + cur = build_cvec(cur, il); cb(cur, "l_out", il); + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/t5-enc.cpp b/examples/talk-llama/models/t5-enc.cpp deleted file mode 100644 index 395dfb51042..00000000000 --- a/examples/talk-llama/models/t5-enc.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include "models.h" - -llm_build_t5_enc::llm_build_t5_enc(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v(); - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - ggml_tensor * pos_bucket_enc = build_inp_pos_bucket_enc(); - - auto * inp_attn = build_attn_inp_no_cache(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm_enc, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; - ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); - - cur = build_attn(inp_attn, - model.layers[il].wo_enc, nullptr, - Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); - cb(cur, "kqv_out", il); - } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm_enc, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - // T5 uses relu, flan-T5 uses gelu-gated - cur = build_ffn(cur, - model.layers[il].ffn_up_enc, NULL, NULL, - model.layers[il].ffn_gate_enc, NULL, NULL, - model.layers[il].ffn_down_enc, NULL, NULL, - NULL, - model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, - model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, - il); - cb(cur, "ffn_out", il); - } - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - cur = inpL; - cb(cur, "result_embd", -1); - - cur = build_norm(cur, - model.output_norm_enc, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - ggml_build_forward_expand(gf, cur); -} diff --git a/examples/talk-llama/models/t5-dec.cpp b/examples/talk-llama/models/t5.cpp similarity index 64% rename from examples/talk-llama/models/t5-dec.cpp rename to examples/talk-llama/models/t5.cpp index 8ca8372bd4c..9f9dfef4012 100644 --- a/examples/talk-llama/models/t5-dec.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -1,6 +1,7 @@ #include "models.h" -llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +template <> +llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -34,24 +35,13 @@ llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_pa // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il); ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b); cur = build_attn(inp_attn_self, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -82,7 +72,7 @@ llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_pa Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc); cur = build_attn(inp_attn_cross, - model.layers[il].wo_cross, nullptr, + model.layers[il].wo_cross, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); @@ -164,3 +154,99 @@ llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_pa ggml_build_forward_expand(gf, cur); } + +template <> +llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * pos_bucket_enc = build_inp_pos_bucket_enc(); + + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; + ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); + + cur = build_attn(inp_attn, + model.layers[il].wo_enc, nullptr, nullptr, + Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); + cb(cur, "kqv_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = build_ffn(cur, + model.layers[il].ffn_up_enc, NULL, NULL, + model.layers[il].ffn_gate_enc, NULL, NULL, + model.layers[il].ffn_down_enc, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, + il); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + cb(cur, "result_embd", -1); + + cur = build_norm(cur, + model.output_norm_enc, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/t5encoder.cpp b/examples/talk-llama/models/t5encoder.cpp new file mode 100644 index 00000000000..5c1f9eb4030 --- /dev/null +++ b/examples/talk-llama/models/t5encoder.cpp @@ -0,0 +1,3 @@ +#include "models.h" + +llm_build_t5encoder::llm_build_t5encoder(const llama_model & model, const llm_graph_params & params) : llm_build_t5(model, params) {} diff --git a/examples/talk-llama/models/wavtokenizer-dec.cpp b/examples/talk-llama/models/wavtokenizer-dec.cpp index 537a0d41248..a7776d9cdc9 100644 --- a/examples/talk-llama/models/wavtokenizer-dec.cpp +++ b/examples/talk-llama/models/wavtokenizer-dec.cpp @@ -93,7 +93,7 @@ llm_build_wavtokenizer_dec::llm_build_wavtokenizer_dec(const llama_model & model cur = build_norm(cur, model.tok_norm, model.tok_norm_b, - LLM_NORM, -1); + LLM_NORM, 0); cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index 3a8dfafcceb..53085ec80f6 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -28,18 +28,8 @@ llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_pa // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -58,7 +48,7 @@ llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index 122c8ca04a5..dc13e53f09f 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -470,6 +470,141 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & return bpe_offsets; } +// Qwen2 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +static std::vector unicode_regex_split_custom_qwen2(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + //if (len > 0) { + // std::string s = ""; + // for(size_t p = end-len; p < end; p++) + // s += unicode_cpt_to_utf8(cpts[p]); + // printf(">>> '%s'\n", s.c_str()); + //} + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive + if (cpt == '\'' && pos+1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); + continue; + } + if (pos+2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; + } + } + } + + // regex: [^\r\n\p{L}\p{N}]?\p{L}+ + if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) { + if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters + pos++; + while (_get_flags(pos).is_letter) { + pos++; + } + _add_token(pos); + continue; + } + } + + // regex: \p{N} + if (flags.is_number) { + pos++; + _add_token(pos); + continue; + } + + // regex: ?[^\s\p{L}\p{N}]+[\r\n]* + auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); + if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos+num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos+num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // regex: \s*[\r\n]+ + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); + } + } + + return bpe_offsets; +} + template static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) { using BidirIt = typename std::basic_string::const_iterator; @@ -753,6 +888,35 @@ static std::vector unicode_regex_split_custom_afmoe(const std::string & return bpe_offsets; } +// regex: [^\n]+|[\n]+ +// splits text into runs of non-newline characters and runs of newline characters +static std::vector unicode_regex_split_custom_newlines(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; + bpe_offsets.reserve(offsets.size()); + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + size_t pos = offset_ini; + while (pos < offset_end) { + const bool is_newline = (cpts[pos] == '\n'); + const size_t run_start = pos; + while (pos < offset_end && (cpts[pos] == '\n') == is_newline) { + pos++; + } + bpe_offsets.push_back(pos - run_start); + } + } + + return bpe_offsets; +} + static std::vector unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { std::vector bpe_offsets; @@ -761,14 +925,18 @@ static std::vector unicode_regex_split_custom(const std::string & text, } else if ( regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" || regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { - bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); + } else if ( + regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { + bpe_offsets = unicode_regex_split_custom_qwen2(text, offsets); } else if (regex_expr == "\\p{Han}+") { // K2's first pattern - handle all K2 patterns together bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets); } else if (regex_expr == "\\p{AFMoE_digits}") { // AFMOE digit pattern - use custom implementation for proper splitting bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); + } else if (regex_expr == "[^\\n]+|[\\n]+") { + bpe_offsets = unicode_regex_split_custom_newlines(text, offsets); } else if (regex_expr == "\\d{1,3}(?=(?:\\d{3})*\\b)") { // tiny_aya digit grouping pattern from tokenizer.json: // {"type": "Split", "pattern": {"Regex": "\\d{1,3}(?=(?:\\d{3})*\\b)"}, "behavior": "Isolated"} @@ -912,7 +1080,7 @@ bool unicode_cpt_is_han(uint32_t cpt) { return false; } -std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { +std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs, bool byte_encode) { // unicode categories static const std::map k_ucat_enum = { { "\\p{N}", unicode_cpt_flags::NUMBER }, @@ -1099,5 +1267,9 @@ std::vector unicode_regex_split(const std::string & text, const std start += offset; } - return unicode_byte_encoding_process(bpe_words); + if (byte_encode) { + return unicode_byte_encoding_process(bpe_words); + } + + return bpe_words; } diff --git a/examples/talk-llama/unicode.h b/examples/talk-llama/unicode.h index 5bd1362ff41..600ab9216b9 100644 --- a/examples/talk-llama/unicode.h +++ b/examples/talk-llama/unicode.h @@ -108,4 +108,4 @@ uint32_t unicode_tolower(uint32_t cpt); bool unicode_cpt_is_han(uint32_t cpt); -std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs); +std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs, bool byte_encode = true); From d537f54999081f59a773e283d7f8cfc06df63664 Mon Sep 17 00:00:00 2001 From: Mariusz Reichert Date: Mon, 4 May 2026 17:28:31 +0200 Subject: [PATCH 244/249] fix(cmake/coreml): join whisper-targets export set; PRIVATE include dir Two latent bugs surfaced together when whisper.cpp is built with -DWHISPER_COREML=ON, both reproducible at CMake configure time: 1. install(TARGETS whisper.coreml) did not join the whisper-targets export set. Since whisper PRIVATE-links to whisper.coreml and is itself in whisper-targets, CMake refuses to generate with install(EXPORT "whisper-targets" ...) includes target "whisper" which requires target "whisper.coreml" that is not in any export set. Add EXPORT whisper-targets to the install (must come before LIBRARY in CMake's install(TARGETS ...) signature). 2. Once whisper.coreml is in the export set, its PUBLIC include dirs are validated against the install interface. The current "." include dir is a raw source-tree path with no $/$ guards and CMake refuses with INTERFACE_INCLUDE_DIRECTORIES property contains path "..." which is prefixed in the source directory. The headers under coreml/ are internal implementation details only consumed by whisper.cpp (in the same directory), so the correct fix is to mark them PRIVATE rather than wrapping them in install/build generator expressions. Verified locally with -DWHISPER_COREML=ON -DGGML_METAL=ON: configure clean, whisper.coreml + libwhisper.dylib build end-to-end. This unblocks the ios-xcode-build CI job on PR #12. QVAC-18300 Co-authored-by: Cursor --- src/CMakeLists.txt | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f360411d704..3a09c7b9157 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -67,8 +67,12 @@ if (WHISPER_COREML) include(DefaultTargetOptions) - target_include_directories(${TARGET} PUBLIC - . + # PRIVATE (not PUBLIC) so the raw source-tree include path doesn't + # propagate through the install/export interface. whisper.coreml is an + # internal helper that only whisper.cpp (in the same directory) consumes; + # external users of `whisper` never need to see these headers. + target_include_directories(${TARGET} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} ) target_link_libraries(${TARGET} PRIVATE ${FOUNDATION_FRAMEWORK} ${COREML_FRAMEWORK}) @@ -79,7 +83,12 @@ if (WHISPER_COREML) ) set_target_properties(${TARGET} PROPERTIES FOLDER "libs") - install(TARGETS ${TARGET} LIBRARY) + # Join the whisper-targets export set so that consumers' install(EXPORT) + # works when whisper PRIVATE-links to whisper.coreml. Without this the + # iOS / xcframework build fails at CMake generate time with + # install(EXPORT "whisper-targets" ...) includes target "whisper" which + # requires target "whisper.coreml" that is not in any export set. + install(TARGETS ${TARGET} EXPORT whisper-targets LIBRARY) endif() if (WHISPER_OPENVINO) From 1318aee92eb807b32ff0419bd431cf0dbd2128b3 Mon Sep 17 00:00:00 2001 From: Mariusz Reichert Date: Mon, 4 May 2026 22:10:23 +0200 Subject: [PATCH 245/249] fix(bindings/java): sync WhisperFullParams JNA layout with whisper.h MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The bindings-java tests testGetDefaultFullParams_Greedy / testGetDefaultFullParams_BeamSearch on PR #12 fail with expected: <5> but was: <0> (greedy.best_of) expected: <5> but was: <-1> (beam_search.beam_size) while whisper_full_default_params() still returns 5 for both — the actual transcription test (testFullTranscribe) produces correct text. Diagnosis: the Java JNA WhisperFullParams Structure is missing fields that exist in the C whisper_full_params struct, so JNA computes wrong offsets and reads garbage at greedy.best_of / beam_search.beam_size. Specifically the Java layout was missing: 1. int32_t seed — added by tetherto's local seed patch between no_speech_thold and greedy (include/whisper.h:553). This single omission shifts every subsequent field by 4 bytes and is the proximate cause of both failing assertions. 2. bool vad — added by upstream 3. const char * vad_model_path 4. whisper_vad_params vad_params (struct) Fix: * New WhisperVadParams.java JNA Structure mirroring whisper_vad_params {threshold, min_speech_duration_ms, min_silence_duration_ms, max_speech_duration_s, speech_pad_ms, samples_overlap}. * Add `public int seed`, `public CBool vad`, `public String vad_model_path`, `public WhisperVadParams vad_params` fields and thread them into getFieldOrder() at the matching positions. Field order in WhisperFullParams.getFieldOrder() now matches the C struct in include/whisper.h field-for-field, so JNA-computed offsets agree with the native side. QVAC-18300 Co-authored-by: Cursor --- .../whispercpp/params/WhisperFullParams.java | 41 ++++++++++--- .../whispercpp/params/WhisperVadParams.java | 61 +++++++++++++++++++ 2 files changed, 95 insertions(+), 7 deletions(-) create mode 100644 bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperVadParams.java diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index 76ce80fb4cc..e1d5de27c65 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -224,6 +224,16 @@ public void suppressNonSpeechTokens(boolean enable) { /** No speech threshold. */ public float no_speech_thold; + /** + * RNG seed for reproducible sampling (when temperature > 0). + * Each decoder uses {@code seed + decoder_index} so concurrent decoders get + * unique seeds. Maps to the {@code seed} field added at + * {@code include/whisper.h:553}; without this field declared here the + * subsequent {@code greedy} / {@code beam_search} struct offsets shift by + * 4 bytes and JNA reads garbage from the C-side defaults. + */ + public int seed; + /** Greedy decoding parameters. */ public GreedyParams greedy; @@ -331,6 +341,21 @@ public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) { public long i_start_rule; public float grammar_penalty; + // Voice Activity Detection (VAD) params -- added by upstream after v1.8.4. + // Without these three fields declared here the C struct's tail is missing + // from the JNA layout, which is fine for read-only callers but corrupts + // the trailing memory whenever a Java caller passes WhisperFullParams + // back into the C ABI (e.g. whisper_full). + + /** Enable VAD pre-filtering inside whisper_full. (default = false) */ + public CBool vad; + + /** Path to the Silero VAD model (only used when {@link #vad} is true). */ + public String vad_model_path; + + /** VAD tuning knobs, mirrors {@code whisper_vad_params}. */ + public WhisperVadParams vad_params; + @Override protected List getFieldOrder() { return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", @@ -343,13 +368,15 @@ protected List getFieldOrder() { "prompt_tokens", "prompt_n_tokens", "language", "detect_language", "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty", "temperature_inc", - "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", - "beam_search", "new_segment_callback", "new_segment_callback_user_data", - "progress_callback", "progress_callback_user_data", - "encoder_begin_callback", "encoder_begin_callback_user_data", - "abort_callback", "abort_callback_user_data", - "logits_filter_callback", "logits_filter_callback_user_data", - "grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty"); + "entropy_thold", "logprob_thold", "no_speech_thold", "seed", + "greedy", "beam_search", "new_segment_callback", + "new_segment_callback_user_data", "progress_callback", + "progress_callback_user_data", "encoder_begin_callback", + "encoder_begin_callback_user_data", "abort_callback", + "abort_callback_user_data", "logits_filter_callback", + "logits_filter_callback_user_data", "grammar_rules", + "n_grammar_rules", "i_start_rule", "grammar_penalty", + "vad", "vad_model_path", "vad_params"); } public static class ByValue extends WhisperFullParams implements Structure.ByValue { diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperVadParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperVadParams.java new file mode 100644 index 00000000000..d7487edf811 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperVadParams.java @@ -0,0 +1,61 @@ +package io.github.ggerganov.whispercpp.params; + +import com.sun.jna.Pointer; +import com.sun.jna.Structure; + +import java.util.Arrays; +import java.util.List; + +/** + * Voice Activity Detection (VAD) parameters. + * Mirrors {@code struct whisper_vad_params} in include/whisper.h. + */ +public class WhisperVadParams extends Structure { + + public WhisperVadParams() { + super(); + } + + public WhisperVadParams(Pointer p) { + super(p); + } + + /** Probability threshold to consider as speech. */ + public float threshold; + + /** Min duration for a valid speech segment. */ + public int min_speech_duration_ms; + + /** Min silence duration to consider speech as ended. */ + public int min_silence_duration_ms; + + /** Max duration of a speech segment before forcing a new segment. */ + public float max_speech_duration_s; + + /** Padding added before and after speech segments. */ + public int speech_pad_ms; + + /** Overlap in seconds when copying audio samples from speech segment. */ + public float samples_overlap; + + @Override + protected List getFieldOrder() { + return Arrays.asList( + "threshold", + "min_speech_duration_ms", + "min_silence_duration_ms", + "max_speech_duration_s", + "speech_pad_ms", + "samples_overlap"); + } + + public static class ByValue extends WhisperVadParams implements Structure.ByValue { + public ByValue() { + super(); + } + + public ByValue(Pointer p) { + super(p); + } + } +} From 47784b9e00dcf1068f334bb30a4b8e89f8875f52 Mon Sep 17 00:00:00 2001 From: Zbigniew Herman Date: Mon, 18 May 2026 16:35:37 +0200 Subject: [PATCH 246/249] test: cover whisper_vad streaming API added by upstream PR #3677 Upstream ggml-org/whisper.cpp PR #3677 added the streaming VAD entry points but shipped no test. Lock the public contract on the tetherto fork so regressions surface immediately: - whisper_vad_detect_speech idempotent (reset is implicit) - whisper_vad_reset_state restores LSTM state exactly - detect_speech == reset_state + detect_speech_no_reset - detect_speech_no_reset on contiguous halves == single-shot detect_speech (state carries across no-reset call boundary) Splits at a 512-sample boundary (Silero v6.2.0 window size) so no mid-stream zero padding is introduced. Uses the bundled silero VAD model and samples/jfk.wav; no whisper transcribe model needed. QVAC-18991 Co-authored-by: Cursor --- tests/CMakeLists.txt | 15 ++++ tests/test-vad-streaming.cpp | 133 +++++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+) create mode 100644 tests/test-vad-streaming.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 09e77ea89c2..6b80b023ffb 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -110,3 +110,18 @@ target_compile_definitions(${VAD_TEST} PRIVATE SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST}) set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en") + +# Streaming VAD test pins the upstream PR #3677 contract: +# whisper_vad_detect_speech_no_reset + whisper_vad_reset_state must +# combine to yield exactly the same per-chunk probabilities as the +# original whisper_vad_detect_speech call. Uses only the silero VAD +# model (no whisper transcribe model needed). +set(VAD_TEST test-vad-streaming) +add_executable(${VAD_TEST} ${VAD_TEST}.cpp) +target_include_directories(${VAD_TEST} PRIVATE ../include ../ggml/include ../examples) +target_link_libraries(${VAD_TEST} PRIVATE common) +target_compile_definitions(${VAD_TEST} PRIVATE + VAD_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/for-tests-silero-v6.2.0-ggml.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") +add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST}) +set_tests_properties(${VAD_TEST} PROPERTIES LABELS "unit;vad") diff --git a/tests/test-vad-streaming.cpp b/tests/test-vad-streaming.cpp new file mode 100644 index 00000000000..4c4d092eefa --- /dev/null +++ b/tests/test-vad-streaming.cpp @@ -0,0 +1,133 @@ +// Regression coverage for the streaming VAD API added by upstream +// ggml-org/whisper.cpp PR #3677 (commit 166c20b4) and pulled into the +// tetherto fork via QVAC-18991. The upstream change did not ship a test, +// so we add one here to lock the contract: +// +// 1. whisper_vad_detect_speech is idempotent (reset is implicit). +// 2. whisper_vad_reset_state restores the LSTM state, so a no-reset +// run after a reset produces the same probs as the very first run. +// 3. Two contiguous whisper_vad_detect_speech_no_reset calls on +// adjacent halves of the input produce the same per-chunk probs +// as a single whisper_vad_detect_speech(full_input) — i.e. the +// LSTM state actually carries across the boundary (within the +// tolerance of running the same graph through two scheduler +// activations). +// +// Split is performed at a multiple of vctx->n_window so that the second +// half starts cleanly on a chunk boundary and no zero-padding is +// introduced mid-stream that would diverge from the single-shot run. + +#include "whisper.h" +#include "common-whisper.h" + +#include +#include +#include +#include + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include + +static std::vector snapshot_probs(struct whisper_vad_context * vctx) { + const int n = whisper_vad_n_probs(vctx); + const float * p = whisper_vad_probs(vctx); + return std::vector(p, p + n); +} + +static void assert_probs_near(const std::vector & a, + const std::vector & b, + float tol, + const char * label) { + assert(a.size() == b.size()); + float worst = 0.0f; + for (size_t i = 0; i < a.size(); ++i) { + const float d = std::fabs(a[i] - b[i]); + if (d > worst) worst = d; + } + printf("%s: max |diff| = %.6f over %zu probs (tol = %.6f)\n", label, worst, a.size(), tol); + assert(worst <= tol); +} + +int main() { + const std::string vad_model_path = VAD_MODEL_PATH; + const std::string sample_path = SAMPLE_PATH; + + std::vector pcmf32; + std::vector> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + + struct whisper_vad_context_params ctx_params = whisper_vad_default_context_params(); + struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params( + vad_model_path.c_str(), ctx_params); + assert(vctx != nullptr); + + // --- Test 1: detect_speech is idempotent (implicit reset). ----------- + assert(whisper_vad_detect_speech(vctx, pcmf32.data(), (int)pcmf32.size())); + const auto probs_first = snapshot_probs(vctx); + + assert(whisper_vad_detect_speech(vctx, pcmf32.data(), (int)pcmf32.size())); + const auto probs_second = snapshot_probs(vctx); + + assert_probs_near(probs_first, probs_second, 0.0f, + "Test1 detect_speech idempotent"); + + // --- Test 2: reset_state restores initial LSTM state. ---------------- + whisper_vad_reset_state(vctx); + assert(whisper_vad_detect_speech_no_reset(vctx, pcmf32.data(), (int)pcmf32.size())); + const auto probs_no_reset_a = snapshot_probs(vctx); + + whisper_vad_reset_state(vctx); + assert(whisper_vad_detect_speech_no_reset(vctx, pcmf32.data(), (int)pcmf32.size())); + const auto probs_no_reset_b = snapshot_probs(vctx); + + assert_probs_near(probs_no_reset_a, probs_no_reset_b, 0.0f, + "Test2 reset_state restores LSTM"); + + // detect_speech (which resets internally) must also match the + // reset+no_reset sequence — proves they share identical semantics + // when starting from a clean state. + assert_probs_near(probs_first, probs_no_reset_a, 0.0f, + "Test2b detect_speech == reset+no_reset"); + + // --- Test 3: streaming carries LSTM state across calls. -------------- + // Split exactly on a chunk boundary so we don't introduce mid-stream + // zero padding. The Silero v6.2.0 VAD model fixture uses a fixed + // 512-sample window at 16 kHz (32 ms); the chunk count is + // ceil(n_samples / 512), with the last chunk zero-padded if short. + constexpr int kSileroWindow = 512; + const int total_chunks = (int)probs_first.size(); + assert(total_chunks == ((int)pcmf32.size() + kSileroWindow - 1) / kSileroWindow); + const int half_chunks = total_chunks / 2; + assert(half_chunks >= 1 && total_chunks - half_chunks >= 1); + const int split_idx = half_chunks * kSileroWindow; + assert(split_idx > 0 && split_idx < (int)pcmf32.size()); + + whisper_vad_reset_state(vctx); + assert(whisper_vad_detect_speech_no_reset(vctx, pcmf32.data(), split_idx)); + const auto probs_part1 = snapshot_probs(vctx); + assert((int)probs_part1.size() == half_chunks); + + assert(whisper_vad_detect_speech_no_reset( + vctx, pcmf32.data() + split_idx, (int)pcmf32.size() - split_idx)); + const auto probs_part2 = snapshot_probs(vctx); + assert((int)probs_part2.size() == total_chunks - half_chunks); + + // Concatenate part1 + part2 and compare against single-shot probs. + std::vector probs_stitched; + probs_stitched.reserve(total_chunks); + probs_stitched.insert(probs_stitched.end(), probs_part1.begin(), probs_part1.end()); + probs_stitched.insert(probs_stitched.end(), probs_part2.begin(), probs_part2.end()); + + // Float-equality is the contract: the per-chunk graph is the same + // and the LSTM state is preserved exactly across the no_reset call + // boundary. If a future refactor introduces tiny numerical drift + // across scheduler resets, bump the tolerance — but never silently. + assert_probs_near(probs_first, probs_stitched, 0.0f, + "Test3 streaming == single-shot"); + + whisper_vad_free(vctx); + return 0; +} From eb63b2b710417a8955b986748858b21e735f50f8 Mon Sep 17 00:00:00 2001 From: Zbigniew Herman Date: Tue, 19 May 2026 18:12:44 +0200 Subject: [PATCH 247/249] ggml : allow GGML_BACKEND_DL with a static core (QVAC-18993) When vcpkg's arm64-android triplet forces VCPKG_LIBRARY_LINKAGE=static (=> BUILD_SHARED_LIBS=OFF) the bundled ggml unconditionally aborts at CMake configure time with: FATAL_ERROR: GGML_BACKEND_DL requires BUILD_SHARED_LIBS even though the static dispatcher + MODULE backend .so files combo actually works: the dispatcher just needs PIC (already gated by the same BUILD_SHARED_LIBS branch below) so it can be dlsym'd from the MODULE-built backend libraries. Three guards changed from `BUILD_SHARED_LIBS` to `BUILD_SHARED_LIBS OR GGML_BACKEND_DL` (FATAL_ERROR removed, GGML_BACKEND_BUILD/SHARED defs on each backend, PIC + GGML_BUILD on the core targets), so the Android dynamic-backend recipe used by qvac-registry-vcpkg's whisper-cpp port (-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_CPU_REPACK=ON) now configures. Mirrors the equivalent change carried in qvac-ext-ggml@speech for the parallel speech-stack consumers (parakeet-cpp / tts-cpp). Validated by an NDK r29 cross-compile of bundled ggml + whisper.cpp with the flags above (all 7 per-arch libggml-cpu-android_armv*_*.so produced clean). Co-authored-by: Cursor --- ggml/src/CMakeLists.txt | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 788ad1dba7a..ecd054633e8 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -185,9 +185,8 @@ endif() # ggml -if (GGML_BACKEND_DL AND NOT BUILD_SHARED_LIBS) - message(FATAL_ERROR "GGML_BACKEND_DL requires BUILD_SHARED_LIBS") -endif() +# QVAC-18993: GGML_BACKEND_DL works with a static core when PIC is +# enabled below (mirrors the speech-stack patch in qvac-ext-ggml). add_library(ggml-base ../include/ggml.h @@ -266,7 +265,7 @@ function(ggml_add_backend_library backend) target_link_libraries(${backend} PRIVATE ggml-base) target_include_directories(${backend} PRIVATE ..) - if (${BUILD_SHARED_LIBS}) + if (BUILD_SHARED_LIBS OR GGML_BACKEND_DL) target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD) target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED) endif() @@ -484,7 +483,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "visionOS") target_compile_definitions(ggml-base PUBLIC _DARWIN_C_SOURCE) endif() -if (BUILD_SHARED_LIBS) +if (BUILD_SHARED_LIBS OR GGML_BACKEND_DL) foreach (target ggml-base ggml) set_target_properties(${target} PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(${target} PRIVATE GGML_BUILD) From 3683de4b7ce479d3d2a96c4748fbf348982827e3 Mon Sep 17 00:00:00 2001 From: Zbigniew Herman Date: Tue, 19 May 2026 18:13:03 +0200 Subject: [PATCH 248/249] ggml-backend : android per-arch CPU variant dlopen fallback (QVAC-18993) Android app packaging keeps native libraries compressed inside the APK with no on-disk directory to scan (AGP's `useLegacyPackaging=false` default since 3.6). The directory-iterator pass in `ggml_backend_load_best` therefore finds nothing on Android and the existing per-search_path `fs::exists` filename fallback also returns false, leaving the loader to return nullptr and the consumer to fail `init_cpu_backend()`. For backends that ship as a single library (Vulkan / OpenCL / ...) the bare `libggml-.so` filename is enough to resolve via Android's in-APK linker lookup, but with `GGML_CPU_ALL_VARIANTS=ON` (the qvac-registry-vcpkg whisper-cpp port default for Android per QVAC-18993) the CPU backend ships only as per-arch variants -- there is no plain `libggml-cpu.so` for the fallback to compose, so the CPU backend silently never registers. Enumerate the known per-arch Android variants as additional candidate names for the "cpu" backend and run each through the standard `ggml_backend_score` selection so the device's HWCAP picks the right tier (armv8.0 baseline through armv9.2_2; matches the variants list emitted by `ggml_add_cpu_backend_variant()` in ggml/src/CMakeLists.txt around lines 410-416). Fast-path for the size-1 candidate case (every backend on every non-Android platform, plus Vulkan / OpenCL / Metal / ... on Android): single load_backend call, identical cost to the previous code path. The score-then-reload loop only runs when there's an actual choice to make. Mirrors qvac-ext-ggml@speech commit 9562ed04 ("ggml-backend: android per-arch CPU variant dlopen fallback", @GustavoA1604, PR #11). Carried here as a separate commit on top of the v1.8.4.3 upstream-sync branch so the whisper-cpp vcpkg port can ship Android dynamic-backend mode without a port-level patch (`patches/0002-...`). Validated by an NDK r29 cross-compile of bundled ggml + whisper.cpp with -DGGML_BACKEND_DL=ON -DBUILD_SHARED_LIBS=OFF -DGGML_CPU_ALL_VARIANTS=ON -DGGML_CPU_REPACK=ON: - all 7 per-arch libggml-cpu-android_armv*_*.so produced clean; - `strings ggml-backend-reg.cpp.o | grep cpu-android_armv` confirms the __ANDROID__ block compiles into the dispatcher object. Co-authored-by: Cursor --- ggml/src/ggml-backend-reg.cpp | 91 +++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 8165ae2c8bb..f3c52baff82 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -546,6 +546,97 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, } } } +#ifdef __ANDROID__ + // QVAC-18993: Android app packaging keeps native libraries + // compressed inside the APK with no on-disk directory to scan + // (the default since AGP 3.6's `useLegacyPackaging=false`). + // The directory-iterator loop above therefore finds nothing on + // Android and the per-search_path `fs::exists` fallback also + // returns false. Try the bare filename instead and let Android's + // dynamic linker resolve it via the in-APK lookup. + // + // For backends that ship as a single library (Vulkan / OpenCL / + // ...) the bare `libggml-.so` filename is enough. + // For the CPU backend with `GGML_CPU_ALL_VARIANTS=ON` there is + // no plain `libggml-cpu.so`, only the per-arch + // `libggml-cpu-android_armv*_*.so` files, so we also + // try each known per-arch variant and let `ggml_backend_score` + // pick the highest-scoring one the device's HWCAP supports. + // + // Mirrors qvac-ext-ggml@speech commit 9562ed04 ("ggml-backend: + // android per-arch CPU variant dlopen fallback") so APK + // consumers of the bundled-ggml whisper-cpp port get the same + // Android-correct CPU init as ggml-speech consumers (parakeet, + // chatterbox, ...). + // + // TODO: keep this list in sync with the + // `ggml_add_cpu_backend_variant(android_armv*_*)` calls in + // ggml/src/CMakeLists.txt. New tiers added there must be + // appended here as well, or devices on the new tier will + // silently fall back to a lower one (perf hit). + std::vector candidate_names = { name_path }; + if (strcmp(name, "cpu") == 0) { + candidate_names.emplace_back("cpu-android_armv8.0_1"); + candidate_names.emplace_back("cpu-android_armv8.2_1"); + candidate_names.emplace_back("cpu-android_armv8.2_2"); + candidate_names.emplace_back("cpu-android_armv8.6_1"); + candidate_names.emplace_back("cpu-android_armv9.0_1"); + candidate_names.emplace_back("cpu-android_armv9.2_1"); + candidate_names.emplace_back("cpu-android_armv9.2_2"); + } + + if (candidate_names.size() == 1) { + // Fast path: Vulkan / OpenCL / Metal / ... single-shot dlopen. + fs::path filename = backend_filename_prefix().native() + + candidate_names[0].native() + + backend_filename_extension().native(); + if (auto reg = get_reg().load_backend(filename, silent)) { + return reg; + } + return nullptr; + } + + // Multi-candidate (Android CPU today): iterate worst -> best with + // a synthetic per-index offset on top of the runtime score so + // that on a device that accepts every variant (e.g. armv9.2 + // phone) the highest tier wins on tie, while a device that + // legitimately supports only the baseline still picks armv8.0_1. + for (size_t idx = 0; idx < candidate_names.size(); ++idx) { + fs::path filename = backend_filename_prefix().native() + + candidate_names[idx].native() + + backend_filename_extension().native(); + dl_handle_ptr handle { dl_load_library(filename) }; + if (!handle) { + if (!silent) { + GGML_LOG_DEBUG("%s: dlopen(%s) failed: %s\n", __func__, + path_str(filename).c_str(), dl_error()); + } + continue; + } + auto score_fn = (ggml_backend_score_t) + dl_get_sym(handle.get(), "ggml_backend_score"); + int s = 1; // base score for backends without ggml_backend_score + if (score_fn) { + s = score_fn(); + if (s == 0) { + continue; + } + } + s += static_cast(idx); +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, + path_str(filename).c_str(), s); +#endif + if (s > best_score) { + best_score = s; + best_path = filename; + } + } + + if (best_score > 0) { + return get_reg().load_backend(best_path, silent); + } +#endif // __ANDROID__ return nullptr; } From 14620c8857fc289313a3b0a82c9ce69accaa046d Mon Sep 17 00:00:00 2001 From: Zbigniew Herman Date: Tue, 19 May 2026 18:13:13 +0200 Subject: [PATCH 249/249] tts-cpp : add missing include in chatterbox_tts.cpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit g_s3gen_cache_refcount (line ~189) is declared as `static std::atomic` and is later used with `.fetch_add()`, `.fetch_sub()`, `.store()`, but the translation unit only pulls in ggml/cstring/stl headers — never `` directly. libstdc++ happens to expose `std::atomic` transitively via `` on most hosts so the build appears clean, but on the ggml-speech sync path where header transitivity changes (the qvac-ext-ggml@speech merge of ggml-org v0.10.2 cuts a few of those transitive paths) the translation unit fails with: chatterbox_tts.cpp:189: variable `std::atomic g_s3gen_cache_refcount' has initializer but incomplete type Reproduces on the pre-merge speech HEAD too -- it was previously hidden by header transitivity. Add `#include ` explicitly. Verified by a clean rebuild of tts-cpp against an `-DBUILD_SHARED_LIBS=ON` install of qvac-ext-ggml@speech HEAD (45dbdecd, day-2 ggml-speech). Co-authored-by: Cursor --- tts-cpp/src/chatterbox_tts.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tts-cpp/src/chatterbox_tts.cpp b/tts-cpp/src/chatterbox_tts.cpp index edd762a7285..603cbc74f3f 100644 --- a/tts-cpp/src/chatterbox_tts.cpp +++ b/tts-cpp/src/chatterbox_tts.cpp @@ -43,6 +43,7 @@ #endif #include +#include #include #include #include